Commit 07096f4

Cleaner workaround for the MoE + compile + activation checkpointing (AC) graph-break issue
Parent: 58c5455


torchtitan/models/llama4/infra/parallelize.py

Lines changed: 31 additions & 7 deletions
@@ -513,14 +513,38 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
     # torch._dynamo.config.capture_scalar_outputs = True
     for layer_id, transformer_block in model.layers.named_children():
         # TODO: remove when torch.compile supports fullgraph=True for MoE
-        fullgraph = True
         if transformer_block.moe_enabled:
-            fullgraph = False
-        transformer_block = torch.compile(
-            transformer_block,
-            backend=compile_config.backend,
-            fullgraph=fullgraph,
-        )
+            transformer_block.moe.experts = torch.compile(
+                transformer_block.moe.experts,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+            transformer_block.moe.shared_experts = torch.compile(
+                transformer_block.moe.shared_experts,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+            # transformer_block.attention = torch.compile(
+            #     transformer_block.attention,
+            #     backend=compile_config.backend,
+            #     fullgraph=True,
+            # )
+            # transformer_block.attention_norm = torch.compile(
+            #     transformer_block.attention_norm,
+            #     backend=compile_config.backend,
+            #     fullgraph=True,
+            # )
+            # transformer_block.ffn_norm = torch.compile(
+            #     transformer_block.ffn_norm,
+            #     backend=compile_config.backend,
+            #     fullgraph=True,
+            # )
+        else:
+            transformer_block = torch.compile(
+                transformer_block,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiling each TransformerBlock with torch.compile")
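
For illustration, a minimal self-contained sketch of the pattern this commit adopts: instead of compiling the whole transformer block with fullgraph=False (tolerating graph breaks caused by MoE routing under activation checkpointing), only the dense expert submodules are wrapped in torch.compile(..., fullgraph=True) while the surrounding routing stays in eager mode. The toy Experts and Block classes below are hypothetical stand-ins, not torchtitan code.

import torch
import torch.nn as nn


class Experts(nn.Module):
    # Toy stand-in for a purely dense expert computation (compiles cleanly).
    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_experts, tokens_per_expert, dim)
        return torch.bmm(x, self.w)


class Block(nn.Module):
    # Toy block whose dispatch logic stays eager; only expert math is compiled.
    def __init__(self, dim: int = 16, num_experts: int = 4):
        super().__init__()
        self.experts = Experts(dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_experts = self.experts.w.shape[0]
        tokens_per_expert = x.shape[0] // num_experts
        routed = x.view(num_experts, tokens_per_expert, -1)
        return self.experts(routed).reshape_as(x)


block = Block()
# Compile only the submodule that supports a full graph, mirroring the commit.
block.experts = torch.compile(block.experts, fullgraph=True)
out = block(torch.randn(8, 16))
print(out.shape)  # torch.Size([8, 16])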
