Commit 2a7a148

[MoE][compile][full ac] weave torch.compile around the FSDP(GroupedExperts) graph break (#1895)
Stacked PRs:
* __->__ #1895

---

This PR changes how we compile MoE layers to work around Compile + AC limitations. When you apply `AC(Compile(block))` or `Compile(AC(block))` and there is a graph break in `block`, the entire block falls back to eager. For llama3, we worked around this problem by addressing all graph breaks. With MoE models, particularly dp2ep, we need to wrap `block.moe.experts` in FSDP, which means we hit graph breaks when tracing `block.moe.experts.__call__`, and so whenever AC was enabled the entire MoE block fell back to eager: https://gist.github.com/xmfan/50f4de1e89d789cd63a21aca9e600132 (note that in the tlparse, graph 0/1 is empty; it corresponds to the block containing the MoE).

The workaround in this PR is to avoid tracing `block.moe.experts.__call__` altogether, by individually wrapping torch.compile around submodules of TransformerBlock. Note that this leaves some perf on the table, since it may exclude some ops in TransformerBlock.forward and MoE.forward. This is an API limitation: we have no way to capture those ops while keeping the compile wrapper decoupled from the model code. The workaround will no longer be necessary once either:
- we can do Compile + AC with graph breaks, or
- we remove the FSDP graph break.

This change introduces a small regression to the non-AC configuration. You can see a small perf dip from [before this PR](https://gist.github.com/xmfan/0b32e95980d263cf3f62869fa4d85921) to [after this PR](https://gist.github.com/xmfan/11561b5406b3f92ecd08da94bc5ee4e3). Given that AC is a necessity for running non-toy configurations of these models, I chose to stick with this implementation to make comparisons easier.

Validated on the DSv3 debug model:
- dp2ep, no AC, no compile: https://gist.github.com/xmfan/927f354158ad36f4c5c1ffedde4e4ebe
- dp2ep, no AC, compile: https://gist.github.com/xmfan/11561b5406b3f92ecd08da94bc5ee4e3
  - before this PR (compile w/ nested graph break): https://gist.github.com/xmfan/0b32e95980d263cf3f62869fa4d85921
- dp2ep, full AC, compile: https://gist.github.com/xmfan/6ed5b48aa51ce0ac2b6bfceb86a0c482
  - before this PR (whole MoE block in eager): https://gist.github.com/xmfan/50f4de1e89d789cd63a21aca9e600132
- dp2ep, full AC, no compile: https://gist.github.com/xmfan/2308355c2aa4814fe3d12243445555fa
- dp2ep, pp, full AC, compile: https://gist.github.com/xmfan/5a1ac23f00abdf93dbcc1539f552e840
- dp2ep, pp, full AC, no compile: https://gist.github.com/xmfan/302cda7191e53ffad5c4dc1e4b8f02de
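To make the per-submodule wrapping described above concrete, here is a minimal, self-contained sketch of the pattern (editor's illustration, not the torchtitan code; the real implementation is in the diff below). The module names `ToyBlock`, `ToyMoE`, and `ToyExperts` are made up: instead of compiling the whole block, which would trace the hook-carrying `experts.__call__`, each child module is compiled separately and the problematic child is skipped.

```python
# Minimal sketch of "compile the children, skip the hook-carrying module".
# All class names here are hypothetical stand-ins, not torchtitan classes.
import torch
import torch.nn as nn


class ToyExperts(nn.Module):
    # Stand-in for FSDP(GroupedExperts): its __call__ carries hooks we do not
    # want inside the compiled region, so we leave it eager.
    def forward(self, x):
        return x * 2.0


class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.router = nn.Linear(8, 8)
        self.experts = ToyExperts()

    def forward(self, x):
        return self.experts(self.router(x))


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)
        self.moe = ToyMoE()

    def forward(self, x):
        return self.moe(self.attn(x))


block = ToyBlock()

# Compile each child of the block individually. For the MoE child, compile its
# own children and skip "experts" so its __call__ is never traced.
for name, child in block.named_children():
    if isinstance(child, ToyMoE):
        for sub_name, sub in child.named_children():
            if sub_name == "experts":
                continue  # keep the hook-carrying module out of the compiled region
            setattr(child, sub_name, torch.compile(sub, fullgraph=True))
    else:
        setattr(block, name, torch.compile(child, fullgraph=True))

out = block(torch.randn(2, 8))
print(out.shape)
```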
1 parent 3e084f4 commit 2a7a148

File tree: 1 file changed (+65, −9)

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 65 additions & 9 deletions
@@ -6,6 +6,9 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointWrapper,
+)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Partial, Replicate, Shard
@@ -30,6 +33,7 @@
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.moe import moe as moe_module
 from torchtitan.tools.logging import logger
 
 
@@ -509,17 +513,69 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
     """
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
-    # torch._dynamo.config.capture_scalar_outputs = True
+    torch._dynamo.config.capture_scalar_outputs = True
     for layer_id, transformer_block in model.layers.named_children():
-        # TODO: remove when torch.compile supports fullgraph=True for MoE
-        fullgraph = True
         if transformer_block.moe_enabled:
-            fullgraph = False
-        transformer_block = torch.compile(
-            transformer_block,
-            backend=compile_config.backend,
-            fullgraph=fullgraph,
-        )
+            # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break
+            # So we must weave compile wrappers around those FSDP hooks to
+            # prevent AC from falling back the whole graph to eager.
+            # TODO: Fix Compile(AC(graph break))
+
+            if isinstance(transformer_block, CheckpointWrapper):
+                # TODO: Make CheckpointWrapper a transparent wrapper
+                # unwrap so that .named_children() works
+                block = transformer_block._checkpoint_wrapped_module
+            else:
+                block = transformer_block
+
+            for attr_name, submod in block.named_children():
+                assert getattr(block, attr_name) == getattr(
+                    transformer_block, attr_name
+                )
+
+                if isinstance(submod, moe_module.MoE):
+                    # avoid graph breaking on the GroupedExperts' FSDP hooks
+                    # by wrapping each submod's forward instead of their __call__
+                    moe = submod
+                    for attr_name, submod in moe.named_children():
+                        if attr_name == "experts":
+                            # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
+                            # https://github.com/pytorch/torchtitan/issues/1940
+                            continue
+                        setattr(
+                            moe,
+                            attr_name,
+                            torch.compile(
+                                submod, backend=compile_config.backend, fullgraph=True
+                            ),
+                        )
+                else:
+                    setattr(
+                        block,
+                        attr_name,
+                        torch.compile(
+                            submod, backend=compile_config.backend, fullgraph=True
+                        ),
+                    )
+
+        else:
+            # If it's not a MoE layer, there is no FSDP(GroupedExperts)
+            # So we can compile the whole block
+            transformer_block = torch.compile(
+                transformer_block,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+
         model.layers.register_module(layer_id, transformer_block)
 
+    moe_module._run_experts_grouped_mm = torch.compile(
+        moe_module._run_experts_grouped_mm,
+        backend=compile_config.backend,
+        fullgraph=True,
+    )
+
+    # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
+    # https://github.com/pytorch/pytorch/issues/166460
+
     logger.info("Compiling each TransformerBlock with torch.compile")
