diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index eef4bda714..b962267753 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -21,6 +21,7 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 from torch.distributed.tensor.placement_types import Placement
+from torch.distributed._functional_collectives import all_to_all_single_autograd
 
 # from torch.distributed._functional_collectives import all_to_all_single_autograd
 
@@ -49,8 +50,8 @@ def backward(ctx, grad_y):
         return grad_x, None, None, None
 
 
-def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+# def all_to_all_single_autograd(x, out_splits, in_splits, group):
+#     return _A2A.apply(x, out_splits, in_splits, group)
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 6d75b4986a..717bcb6abb 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -110,9 +110,6 @@ def parallelize_llama(
     if job_config.training.compile:
         apply_compile(model)
 
-    # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
-    torch._dynamo.config.capture_scalar_outputs = True
-
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
@@ -460,3 +457,6 @@ def apply_compile(model: nn.Module):
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiling each TransformerBlock with torch.compile")
+
+    # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
+    torch._dynamo.config.capture_scalar_outputs = True
diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py
index 949f4cf052..7a535ab9ba 100644
--- a/torchtitan/experiments/llama4/model/args.py
+++ b/torchtitan/experiments/llama4/model/args.py
@@ -20,8 +20,8 @@
 @dataclass
 class TransformerModelArgs(BaseModelArgs):
     dim: int = 4096
-    n_layers: int = 32
-    n_heads: int = 32
+    n_layers: int = 2
+    n_heads: int = 2
     n_kv_heads: int | None = None
     vocab_size: int = 202048
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml
index 0d1ed83628..c72b67cc68 100644
--- a/torchtitan/experiments/llama4/train_configs/debug_model.toml
+++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml
@@ -64,7 +64,7 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "selective"  # ["none", "selective", "full"]
+mode = "none"  # ["none", "selective", "full"]
 selective_ac_option = '2'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 40bd6c2cca..c875f6da48 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -413,10 +413,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 * top_scores_experts_sorted.reshape(-1, 1)
             ).to(x.dtype)
 
+        fsdp_state = self.experts._get_fsdp_state()
+        args, _ = fsdp_state._pre_forward(module=self.experts, args=(routed_input, num_tokens_per_expert), kwargs={})
+        routed_input, num_tokens_per_expert = args
+
         # shape (bs*slen*top_k, dim)
         routed_output = self.experts(routed_input, num_tokens_per_expert)
 
         if not self.score_before_experts:
+            assert False
             routed_output = (
                 routed_output.to(torch.float32)
                 * top_scores_experts_sorted.reshape(-1, 1)
@@ -428,6 +433,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             out = torch.zeros_like(x)
 
+        routed_output = fsdp_state._post_forward(module=self.experts, input=(routed_input, num_tokens_per_expert), output=routed_output)
+
         out = out.scatter_add(
             dim=0, index=token_indices_experts_sorted, src=routed_output
         )
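Note on the torchtitan/models/moe.py hunks above: they drive the experts module's FSDP2 machinery by hand, via the private _get_fsdp_state() / _pre_forward() / _post_forward() APIs, around the grouped-experts call. The snippet below is a minimal sketch of that same call sequence on a toy module, not torchtitan code: it assumes an already-initialized process group, a module wrapped with fully_shard (FSDP2), and a recent PyTorch where fully_shard is importable from torch.distributed.fsdp; the ToyExperts class and the helper name are invented for illustration.

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2; assumes a recent PyTorch release


class ToyExperts(nn.Module):
    # Stand-in for the grouped experts module; purely illustrative.
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def forward_with_manual_fsdp_hooks(experts: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Assumes experts was already wrapped with fully_shard(experts), which swaps in an
    # FSDPModule subclass that exposes _get_fsdp_state() at runtime.
    fsdp_state = experts._get_fsdp_state()
    # Pre-forward: unshard (all-gather) this module's parameters and take the possibly-adjusted args.
    args, _ = fsdp_state._pre_forward(module=experts, args=(x,), kwargs={})
    (x,) = args
    out = experts(x)
    # Post-forward: reshard the parameters and attach the hooks needed for backward.
    return fsdp_state._post_forward(module=experts, input=(x,), output=out)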