From ece90a23ff3ce1c14d98b0de96a24876e2b8b49f Mon Sep 17 00:00:00 2001
From: IvanKobzarev
Date: Fri, 12 Sep 2025 01:00:58 -0700
Subject: [PATCH] [WIP] async_tp shape mismatch in rs+mm repro

---
 .../auto_parallel/parallelize_llama.py | 105 +++++++++++++++++-
 torchtitan/models/llama3/__init__.py   |   2 +-
 2 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py
index 6648f29ab8..2a9eae328d 100644
--- a/torchtitan/experiments/auto_parallel/parallelize_llama.py
+++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -10,9 +10,8 @@
 
 from autoparallel.api import AutoParallel
 
-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
-from torch.distributed.tensor.placement_types import Replicate, Shard
+from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
@@ -20,6 +19,89 @@
 from torchtitan.tools.logging import logger
 
 
+def group_mm_nodes_with_its_gradients(nodes):
+    fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta]
+    bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta]
+    assert len(fwd_nodes) * 2 == len(bwd_nodes)
+    res = {}
+    for fwd_node in fwd_nodes:
+        o = []
+        for bwd_node in bwd_nodes:
+            if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]:
+                o.append(bwd_node)
+        assert len(o) == 2
+        res[fwd_node] = o
+    return res
+
+
+def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False):
+    # out = x @ w - S(0)R, RS(1) -> S(0)S(1)
+    # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0)
+    # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P
+
+    add_node_constraint = autop.sharding_optimizer.add_node_constraint
+    fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes)
+    fwd_nodes = list(fwd_bwd_groups.keys())
+    dim1 = 0 if feat_dim == 1 else 1
+    dim2 = 1 if feat_dim == 1 else 0
+    # assume there are 7 mm nodes per transformer block
+    # skip last mm as it's the final projection layer
+    assert (
+        len(fwd_nodes) - 1
+    ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}"
+    for block in range(0, len(fwd_nodes) - 1, 7):
+        fwd_nodes_block = fwd_nodes[block : block + 7]
+        # force the first 3 mm nodes to be S(0)S(1)
+        the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6]
+        for n in the_nodes:
+            add_node_constraint(n, (Shard(0), Shard(feat_dim)))
+            add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate()))
+            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1)))
+
+            if bwd_constraint:
+                bwd_nodes = fwd_bwd_groups[n]
+                # first is g_w, second is g_x
+                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1)))
+                add_node_constraint(bwd_nodes[1], (Shard(0), Partial()))
+
+        # add reduction to finish TP, yielding S(0)P
+        the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7]
+        for n in the_nodes:
+            add_node_constraint(n, (Shard(0), Partial()))
+            add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim)))
+            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0)))
+
+            if bwd_constraint:
+                bwd_nodes = fwd_bwd_groups[n]
+                # first is g_w, second is g_x
+                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2)))
+                add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim)))
+
+
+def add_tp_constraints(autop):
+    mm_nodes = autop.gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.mm.default
+    )
+    einsum_nodes = autop.gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.einsum.default
+    )
+    assert (len(mm_nodes) > 0) ^ (
+        len(einsum_nodes) > 0
+    ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}"
+    feat_dim = 1 if len(mm_nodes) > 0 else 2
+    tgt_nodes = mm_nodes + einsum_nodes
+    force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True)
+
+    if einsum_nodes:
+        # add sequence parallelism if we have einsum nodes
+        autop.sharding_optimizer.add_node_constraint(
+            list(tgt_nodes[3].users)[0], (Shard(0), Partial())
+        )
+        autop.sharding_optimizer.add_node_constraint(
+            list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1))
+        )
+
+
 def parallelize_llama(
     model,
     parallel_dims: ParallelDims,
@@ -33,6 +115,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +145,17 @@ def input_fn():
         lambda bucket_idx: 1000 / parallel_dims.tp
     )
 
+    # XXX MICROPIPELINE
+    enable_async_tp = True
+    if enable_async_tp:
+        mesh = world_mesh
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+        torch._inductor.config._micro_pipeline_tp = True
+        torch._inductor.config.reorder_for_compute_comm_overlap = False
+    # XXX--- MICROPIPELINE
+
     # bail out
     # model = model_fn()
     # return model
@@ -78,6 +172,7 @@ def input_fn():
         world_mesh,
         mp_policy=mp_policy,
         compile=job_config.compile,
+        repeated_subgraphs=True,
     ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
@@ -101,7 +196,8 @@ def input_fn():
         )
         out_sharding = x_sharding
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         if loss_parallel_enabled:
             out_sharding = tuple(
@@ -111,6 +207,9 @@ def input_fn():
             )
         autop.add_input_constraints([x_sharding])
         autop.add_output_constraints([out_sharding])
+        enable_manual_constraint = True
+        if enable_manual_constraint:
+            add_tp_constraints(autop)
         t0 = time.time()
         sharding_placement = autop.optimize_placement()
         t1 = time.time()
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index a34b4463f8..2ecf6f4d65 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -42,7 +42,7 @@
     ),
     "8B": TransformerModelArgs(
         dim=4096,
-        n_layers=32,
+        n_layers=1,
         n_heads=32,
         n_kv_heads=8,
         ffn_dim_multiplier=1.3,
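
Note (not part of the patch): the async-TP path this repro exercises needs two pieces of setup, both visible in the XXX MICROPIPELINE block above: the TP process group must be registered with symmetric memory, and Inductor's micro-pipeline TP pass must be switched on. A minimal standalone sketch of that setup follows; the mesh shape and the "dp"/"tp" dimension names are placeholder assumptions for illustration, and the snippet only mirrors what the patch does inside parallelize_llama, it is not the repro itself.

    # Sketch: enable async-TP (micro-pipeline TP) for a "tp" sub-mesh.
    # Mesh sizes and dim names below are assumptions, not from the patch.
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed._symmetric_memory import enable_symm_mem_for_group

    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

    # Register the TP group with symmetric memory so Inductor can fuse
    # mm + reduce-scatter / all-gather into micro-pipelined collectives.
    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
    torch._inductor.config._micro_pipeline_tp = True
    # As in the patch, turn off the generic comm/compute reordering pass
    # while micro-pipelining is active.
    torch._inductor.config.reorder_for_compute_comm_overlap = False

Run under torchrun with 8 ranks this would give a 2x4 dp-by-tp mesh; the patch obtains the equivalent group from parallel_dims.world_mesh instead of building its own mesh.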