@@ -6,6 +6,9 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointWrapper,
+)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Partial, Replicate, Shard
@@ -30,6 +33,7 @@
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.moe import moe as moe_module
 from torchtitan.tools.logging import logger
 
 
@@ -509,17 +513,69 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
     """
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
-    # torch._dynamo.config.capture_scalar_outputs = True
+    torch._dynamo.config.capture_scalar_outputs = True
     for layer_id, transformer_block in model.layers.named_children():
-        # TODO: remove when torch.compile supports fullgraph=True for MoE
-        fullgraph = True
         if transformer_block.moe_enabled:
-            fullgraph = False
-        transformer_block = torch.compile(
-            transformer_block,
-            backend=compile_config.backend,
-            fullgraph=fullgraph,
-        )
+            # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break,
+            # so we must weave compile wrappers around those FSDP hooks to
+            # prevent AC from falling back to eager for the whole graph.
+            # TODO: Fix Compile(AC(graph break))
+
+            if isinstance(transformer_block, CheckpointWrapper):
+                # TODO: Make CheckpointWrapper a transparent wrapper
+                # unwrap so that .named_children() works
+                block = transformer_block._checkpoint_wrapped_module
+            else:
+                block = transformer_block
+
+            for attr_name, submod in block.named_children():
+                assert getattr(block, attr_name) == getattr(
+                    transformer_block, attr_name
+                )
+
+                if isinstance(submod, moe_module.MoE):
+                    # avoid graph breaks on the GroupedExperts' FSDP hooks
+                    # by wrapping each submodule's forward instead of its __call__
+                    moe = submod
+                    for attr_name, submod in moe.named_children():
+                        if attr_name == "experts":
+                            # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
+                            # https://github.com/pytorch/torchtitan/issues/1940
+                            continue
+                        setattr(
+                            moe,
+                            attr_name,
+                            torch.compile(
+                                submod, backend=compile_config.backend, fullgraph=True
+                            ),
+                        )
+                else:
+                    setattr(
+                        block,
+                        attr_name,
+                        torch.compile(
+                            submod, backend=compile_config.backend, fullgraph=True
+                        ),
+                    )
+
+        else:
+            # If it's not a MoE layer, there is no FSDP(GroupedExperts),
+            # so we can compile the whole block.
+            transformer_block = torch.compile(
+                transformer_block,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+
         model.layers.register_module(layer_id, transformer_block)
 
+    moe_module._run_experts_grouped_mm = torch.compile(
+        moe_module._run_experts_grouped_mm,
+        backend=compile_config.backend,
+        fullgraph=True,
+    )
+
+    # NOTE: We don't compile the for-loop code path due to an issue with unbacked symints:
+    # https://github.com/pytorch/pytorch/issues/166460
+
     logger.info("Compiling each TransformerBlock with torch.compile")
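
For readers outside the torchtitan codebase, here is a minimal, self-contained sketch of the per-submodule compilation pattern used above. `Block`, `FakeMoE`, and `FakeExperts` are stand-ins invented for illustration (they are not torchtitan classes); only `torch.compile` and `nn.Module.named_children` are real APIs. The point is that compiling each child separately keeps a graph break on one child (e.g. from hooks FSDP registers on the experts) from forcing the whole block back to eager.

```python
import torch
import torch.nn as nn


class FakeExperts(nn.Module):
    # stand-in for GroupedExperts; imagine FSDP hooks registered on this module
    def forward(self, x):
        return torch.relu(x)


class FakeMoE(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.router = nn.Linear(dim, dim)
        self.experts = FakeExperts()

    def forward(self, x):
        return self.experts(self.router(x))


class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attention = nn.Linear(dim, dim)
        self.moe = FakeMoE(dim)

    def forward(self, x):
        return self.moe(self.attention(x))


block = Block(16)

# Compile the MoE's children one by one, skipping the experts so any hooks on
# them stay outside the compiled regions.
for name, child in block.moe.named_children():
    if name == "experts":
        continue
    setattr(block.moe, name, torch.compile(child, fullgraph=True))

# Compile every non-MoE child of the block as a whole.
for name, child in block.named_children():
    if not isinstance(child, FakeMoE):
        setattr(block, name, torch.compile(child, fullgraph=True))

out = block(torch.randn(2, 16))  # attention and router run compiled; experts stay eager
```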
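
Similarly, the module-level rebinding of `_run_experts_grouped_mm` can be sketched in isolation. `grouped_mm_like` below is a hypothetical stand-in, not the torchtitan helper; the design point is that call sites which resolve the function name through the module at call time pick up the compiled version, while the for-loop variant is simply left uncompiled.

```python
import sys

import torch


def grouped_mm_like(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for a grouped-matmul helper
    return torch.einsum("bi,io->bo", x, w)


# Rebind the module-level name to the compiled callable, mirroring
# `moe_module._run_experts_grouped_mm = torch.compile(...)` above.
this_module = sys.modules[__name__]
this_module.grouped_mm_like = torch.compile(grouped_mm_like, fullgraph=True)

y = grouped_mm_like(torch.randn(4, 8), torch.randn(8, 3))
print(y.shape)  # torch.Size([4, 3])
```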