5 changes: 3 additions & 2 deletions torchtitan/distributed/expert_parallel.py
@@ -21,6 +21,7 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 from torch.distributed.tensor.placement_types import Placement
+from torch.distributed._functional_collectives import all_to_all_single_autograd
 
 
 # from torch.distributed._functional_collectives import all_to_all_single_autograd
@@ -49,8 +50,8 @@ def backward(ctx, grad_y):
         return grad_x, None, None, None
 
 
-def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+# def all_to_all_single_autograd(x, out_splits, in_splits, group):
+#     return _A2A.apply(x, out_splits, in_splits, group)
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
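The change above drops the hand-rolled _A2A autograd wrapper in favor of the all_to_all_single_autograd shipped in torch.distributed._functional_collectives, which takes the same (input, output_splits, input_splits, group) arguments. A minimal usage sketch, assuming a recent PyTorch where that function is available and a process group already initialized (e.g. via torchrun); shuffle_tokens is an illustrative name, not code from the repo:

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single_autograd

def shuffle_tokens(x: torch.Tensor, out_splits: list[int], in_splits: list[int]) -> torch.Tensor:
    # Same call shape as the removed wrapper; the functional collective carries
    # gradients itself, so no custom autograd.Function is needed on the caller side.
    return all_to_all_single_autograd(x, out_splits, in_splits, dist.group.WORLD)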
6 changes: 3 additions & 3 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -110,9 +110,6 @@ def parallelize_llama(
     if job_config.training.compile:
         apply_compile(model)
 
-    # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
-    torch._dynamo.config.capture_scalar_outputs = True
-
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
@@ -460,3 +457,6 @@ def apply_compile(model: nn.Module):
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiling each TransformerBlock with torch.compile")
+
+    # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
+    torch._dynamo.config.capture_scalar_outputs = True
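This moves the capture_scalar_outputs toggle from parallelize_llama into apply_compile, so it is set whenever compilation is applied. The flag matters because token-choice routing reads per-expert token counts back as Python scalars, which would otherwise graph-break under torch.compile. A toy illustration, assuming a recent PyTorch; scale_by_token_count is hypothetical and not from the repo:

import torch

# Without this flag, Dynamo graph-breaks whenever a traced region reads a tensor
# back as a Python scalar (.item()/.tolist()), as token-choice routing does for
# its per-expert token counts.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def scale_by_token_count(counts: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    n = counts.sum().item()  # data-dependent scalar becomes an unbacked SymInt
    return x * n

print(scale_by_token_count(torch.tensor([3, 5]), torch.ones(4)))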
4 changes: 2 additions & 2 deletions torchtitan/experiments/llama4/model/args.py
@@ -20,8 +20,8 @@
 @dataclass
 class TransformerModelArgs(BaseModelArgs):
     dim: int = 4096
-    n_layers: int = 32
-    n_heads: int = 32
+    n_layers: int = 2
+    n_heads: int = 2
     n_kv_heads: int | None = None
     vocab_size: int = 202048
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
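With these defaults, TransformerModelArgs now describes a tiny debug-sized model out of the box. A quick sanity check, assuming torchtitan is importable from the working tree:

from torchtitan.experiments.llama4.model.args import TransformerModelArgs

args = TransformerModelArgs()  # dataclass defaults only
assert args.n_layers == 2 and args.n_heads == 2
assert args.dim == 4096 and args.vocab_size == 202048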
(training-config .toml; file path not shown)
@@ -64,7 +64,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "none" # ["none", "selective", "full"]
 selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]
7 changes: 7 additions & 0 deletions torchtitan/models/moe.py
@@ -413,10 +413,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 * top_scores_experts_sorted.reshape(-1, 1)
             ).to(x.dtype)
 
+        fsdp_state = self.experts._get_fsdp_state()
+        args, _ = fsdp_state._pre_forward(module=self.experts, args=(routed_input, num_tokens_per_expert), kwargs={})
+        routed_input, num_tokens_per_expert = args
+
         # shape (bs*slen*top_k, dim)
         routed_output = self.experts(routed_input, num_tokens_per_expert)
 
         if not self.score_before_experts:
+            assert False
             routed_output = (
                 routed_output.to(torch.float32)
                 * top_scores_experts_sorted.reshape(-1, 1)
@@ -428,6 +433,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             out = torch.zeros_like(x)
 
+        routed_output = fsdp_state._post_forward(module=self.experts, input=(routed_input, num_tokens_per_expert), output=routed_output)
+
         out = out.scatter_add(
             dim=0, index=token_indices_experts_sorted, src=routed_output
         )
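The added lines run FSDP's pre- and post-forward logic for self.experts by hand, through the private FSDP2 internals _get_fsdp_state, _pre_forward, and _post_forward; those signatures are taken from the diff and may change between PyTorch releases. As a rough public-API analogue only (it is not what the diff calls, and it skips the backward-hook bookkeeping that _post_forward performs), FSDP2's FSDPModule exposes unshard() and reshard() for explicitly gathering and freeing a submodule's parameters. The helper below is a hypothetical sketch assuming torch >= 2.6 with fully_shard applied to experts:

from torch.distributed.fsdp import FSDPModule  # FSDP2 (fully_shard) wrapper type

def run_experts(experts, routed_input, num_tokens_per_expert):
    # Explicitly all-gather the expert weights before the grouped call ...
    if isinstance(experts, FSDPModule):
        experts.unshard()
    routed_output = experts(routed_input, num_tokens_per_expert)
    # ... and drop the unsharded copies again right after.
    if isinstance(experts, FSDPModule):
        experts.reshard()
    return routed_output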