Commit 02990b0

[simplefsdp] fix region ac in zero2-style FSDP (#1970)
After some offline discussion, we've concluded that life would be easier if we put SimpleFSDP's checkpoint logic for `reshard_after_forward` into the compiler. The AC annotation part is borrowed from AP: [LINK](https://github.com/meta-pytorch/autoparallel/blob/main/autoparallel/activation_checkpointing.py#L69).

**Trace and Loss Check** (all with torch.compile enabled)

reshard_after_fwd = False

1. SAC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-05-06_rank0_trace.json)) [trace and loss screenshots]
2. Full AC + llama3 [(trace)]() [trace and loss screenshots]
3. No AC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-03-50_rank0_trace.json)) [trace and loss screenshots]

reshard_after_fwd = True

1. SAC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-34-24_rank0_trace.json)) [trace screenshot]
2. Full AC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-36-27_rank0_trace.json)) [trace screenshot]
3. No AC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-02-44_rank0_trace.json)) [trace screenshot]
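In short, the compile backend is wrapped so that a joint-graph pass tags SimpleFSDP all-gather nodes with a checkpoint policy; the partitioner then recomputes them in backward when `reshard_after_fwd = True`, or saves their outputs when it is `False`. Below is a condensed, simplified sketch of that mechanism; the helper names here are illustrative only, and the real pass in `backend.py` and `reshard_after_forward.py` further down also tags the matching `wait_tensor`, slice, and dtype-cast nodes.

```python
import torch
import torch._functorch.config as functorch_config
from torch.utils.checkpoint import CheckpointPolicy


def _annotate_fsdp_all_gathers(gm: torch.fx.GraphModule, reshard_after_forward: bool):
    # Tag each all_gather in the joint fwd/bwd graph so the partitioner either
    # recomputes it in backward (zero3-style reshard) or saves its output (zero2-style).
    policy = (
        CheckpointPolicy.MUST_RECOMPUTE
        if reshard_after_forward
        else CheckpointPolicy.MUST_SAVE
    )
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target
            == torch.ops._c10d_functional.all_gather_into_tensor.default
        ):
            node.meta["recompute"] = policy
            # A large ac_graph_id keeps these nodes in their own AC region.
            node.meta["ac_graph_id"] = 100000
    return gm


def _wrap_backend(backend, fsdp_reshard_after_forward: bool):
    # The annotation has to run on the joint graph, before partitioning, so it
    # is installed as functorch's joint_custom_pass only while compiling.
    def joint_ac_pass(gm, example_inputs):
        gm = _annotate_fsdp_all_gathers(gm, fsdp_reshard_after_forward)
        gm.recompile()
        return gm

    def compiled(*args, **kwargs):
        with functorch_config.patch("joint_custom_pass", joint_ac_pass):
            return backend(*args, **kwargs)

    return compiled
```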
1 parent f4514ef commit 02990b0

File tree

6 files changed: +167 -97 lines


torchtitan/experiments/simple_fsdp/README.md

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 [![integration and numerics tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain)
 [![arXiv](https://img.shields.io/badge/arXiv-2411.00284-b31b1b.svg)](https://arxiv.org/abs/2411.00284)
 
-💡 **Note**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via
+💡 **Note 1**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via
 ```bash
 pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
 ```
 
+💡 **Note 2**: Some of SimpleFSDP's functionalities (e.g., reshard_after_forward) is implemented with torch.compile. It is always recommended to open compile (`--compile.enable`) to see desired correct functionality.
+
 This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations.
 
 ### Run SimpleFSDP Training on Llama3 & DeepSeek_v3

torchtitan/experiments/simple_fsdp/backend.py

Lines changed: 29 additions & 6 deletions
@@ -4,20 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Union
+from typing import Any
 
 import torch
+import torch._functorch.config as functorch_config
 
+from .reshard_after_forward import annotate_fsdp_all_gather
 
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+
+def get_compile_backend(
+    backend_name: str, fsdp_reshard_after_forward: bool
+) -> callable:
     # return the compile backends used in SimpleFSDP training
     # Step1: check if backend_name is inside available torch.compile backends
     # Step2: check if the backend_name has been registered as a customized backend
     available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
-    if backend_name in available_torch_backend:
-        return backend_name
 
-    if backend_name == "aot_eager_autobucketing":
+    if backend_name in available_torch_backend:
+        backend = torch._dynamo.lookup_backend(backend_name)
+    elif backend_name == "aot_eager_autobucketing":
         # Perform auto optimization in aten fx-level and execute code in aot_eager backend
         # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
@@ -46,4 +51,22 @@ def aten_autobucketing_reordering_pass(
     else:
         raise AssertionError(f"Unsupported customized backend: {backend_name}")
 
-    return backend
+    def joint_ac_pass(
+        gm: torch.fx.GraphModule, example_inputs: Any
+    ) -> torch.fx.GraphModule:
+        # this pass implements simplefsdp's fsdp_reshard_after_forward behavior
+        # when fsdp_reshard_after_forward set to True, it will annotate simple_fsdp AG
+        # to CheckpointPolicy.MUST_RECOMPUTE.
+        # when fsdp_reshard_after_forward set to False, it will annotate simple_fsdp AG
+        # to CheckpointPolicy.MUST_SAVE.
+        gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward)
+        gm.recompile()
+        return gm
+
+    def simple_fsdp_custom_pass(*args, **kwargs):
+        # the ac pass has to operate in a joint graph before partitioner for ac
+        # annotation to take into effect.
+        with functorch_config.patch("joint_custom_pass", joint_ac_pass):
+            return backend(*args, **kwargs)
+
+    return simple_fsdp_custom_pass
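For reference, a minimal sketch of how the wrapped backend returned by `get_compile_backend` is meant to be consumed; a toy module stands in for the SimpleFSDP-wrapped model, and `aot_eager` is just one of the stock torch.compile backends. The `parallelize.py` changes below do exactly this, driven by `job_config`.

```python
import torch

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend

# Toy stand-in for the SimpleFSDP-wrapped transformer.
model = torch.nn.Linear(8, 8)

backend = get_compile_backend("aot_eager", fsdp_reshard_after_forward=True)
model = torch.compile(model, backend=backend, fullgraph=True)
```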

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 28 additions & 21 deletions
@@ -10,16 +10,18 @@
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+
+from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.deepseek_v3.infra.parallelize import (
-    apply_ac,
     apply_moe_ep_tp,
     apply_non_moe_tp,
 )
 from torchtitan.tools.logging import logger
 
-from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
+from ..backend import get_compile_backend
 
+from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
 
 # Adapted from llama4/infra/parallelize.py
 def parallelize_deepseekv3(
@@ -91,20 +93,6 @@ def parallelize_deepseekv3(
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
 
-    match job_config.parallelism.fsdp_reshard_after_forward:
-        case "always":
-            reshard_after_forward = True
-        case "never":
-            reshard_after_forward = False
-        case "default":
-            # For PP, by default do not reshard after forward to avoid per-microbatch
-            # all-gathers, which can be expensive and non-overlapped
-            reshard_after_forward = not parallel_dims.pp_enabled
-        case _:
-            raise ValueError(
-                f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
-            )
-
     # apply data parallel
     dp_mesh: DeviceMesh | None = None
     if (
@@ -155,9 +143,7 @@ def parallelize_deepseekv3(
                     transformer_block.moe.experts,
                     dp_mod_ep_mesh,
                     dp_mode,
-                    ac_mode=job_config.activation_checkpoint.mode,
                     mp_policy=mp_policy,
-                    reshard_after_forward=reshard_after_forward,
                     shard_dim=experts_shard_dim,
                     reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
                 )
@@ -166,9 +152,7 @@ def parallelize_deepseekv3(
             model,
            dp_mesh,
            dp_mode,
-            ac_mode=job_config.activation_checkpoint.mode,
            mp_policy=mp_policy,
-            reshard_after_forward=reshard_after_forward,
        )
 
    logger.info(
@@ -178,6 +162,29 @@ def parallelize_deepseekv3(
     if job_config.compile.enable:
         torch._inductor.config.reorder_for_peak_memory = False
         torch._dynamo.config.capture_scalar_outputs = True
-        model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
+
+        match job_config.parallelism.fsdp_reshard_after_forward:
+            case "always":
+                fsdp_reshard_after_forward = True
+            case "never":
+                fsdp_reshard_after_forward = False
+            case "default":
+                # For PP, by default do not reshard after forward to avoid per-microbatch
+                # all-gathers, which can be expensive and non-overlapped
+                fsdp_reshard_after_forward = not parallel_dims.pp_enabled
+            case _:
+                raise ValueError(
+                    f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
+                )
+
+        backend = (
+            getattr(job_config.compile, "model_backend_override", None)
+            or job_config.compile.backend
+        )
+        model = torch.compile(
+            model,
+            backend=get_compile_backend(backend, fsdp_reshard_after_forward),
+            fullgraph=True,
+        )
 
     return model

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 16 additions & 17 deletions
@@ -112,41 +112,40 @@ def parallelize_llama(
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )
 
-        match job_config.parallelism.fsdp_reshard_after_forward:
-            case "always":
-                reshard_after_forward = True
-            case "never":
-                reshard_after_forward = False
-            case "default":
-                # For PP, by default do not reshard after forward to avoid per-microbatch
-                # all-gathers, which can be expensive and non-overlapped
-                reshard_after_forward = not parallel_dims.pp_enabled
-            case _:
-                raise ValueError(
-                    f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
-                )
-
         model = data_parallel(
             model,
             parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
             mode=dp_mode,
-            ac_mode=job_config.activation_checkpoint.mode,
             mp_policy=mp_policy,
-            reshard_after_forward=reshard_after_forward,
         )
         logger.info(
             "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
         )
 
     if job_config.compile.enable and "model" in job_config.compile.components:
         torch._inductor.config.reorder_for_peak_memory = False
+
+        match job_config.parallelism.fsdp_reshard_after_forward:
+            case "always":
+                fsdp_reshard_after_forward = True
+            case "never":
+                fsdp_reshard_after_forward = False
+            case "default":
+                # For PP, by default do not reshard after forward to avoid per-microbatch
+                # all-gathers, which can be expensive and non-overlapped
+                fsdp_reshard_after_forward = not parallel_dims.pp_enabled
+            case _:
+                raise ValueError(
+                    f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
+                )
+
         backend = (
             getattr(job_config.compile, "model_backend_override", None)
             or job_config.compile.backend
        )
         model = torch.compile(
             model,
-            backend=get_compile_backend(backend),
+            backend=get_compile_backend(backend, fsdp_reshard_after_forward),
             fullgraph=True,
         )

torchtitan/experiments/simple_fsdp/reshard_after_forward.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.utils.checkpoint import CheckpointPolicy
+
+
+def is_graph_input(node: torch.fx.Node) -> bool:
+    return node.op == "placeholder"
+
+
+def is_wait_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == torch.ops._c10d_functional.wait_tensor.default
+    )
+
+
+def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+    )
+
+
+def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the node is a wait_tensor node that is the result of an all_gather
+    that can be arbitrarily prefetched, i.e., if all its recursive inputs are
+    single-input operators that leads to a graph input.
+    """
+    if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
+        n: torch.fx.Node = node.all_input_nodes[0]
+        while len(n.all_input_nodes) == 1:
+            if is_graph_input(n.all_input_nodes[0]):
+                return True
+            n = n.all_input_nodes[0]
+    return False
+
+
+def annotate_fsdp_all_gather(
+    gm: torch.fx.GraphModule, reshard_after_forward: bool
+) -> None:
+    """
+    Force recompute all_gather nodes from simple fsdp in the graph.
+    This pass should be added in torch._inductor.config.joint_custom_post_pass
+    """
+    graph = gm.graph
+
+    def force_recompute_node(node):
+        if reshard_after_forward:
+            node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
+        else:
+            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
+        # ac_graph_id is used in the partitioner to decide
+        # if two nodes which have AC applied come from a different
+        # AC regions. This is needed because nodes in the boundary
+        # of two AC regions are marked as MUST_SAVE. In our case
+        # we just add a large value of ac_graph_id so that
+        # all nodes we tag for recomputation do indeed get recomputed
+        # and are not influenced by other nodes in the graph with
+        # nearby ac_graph_id values
+        node.meta["ac_graph_id"] = 100000
+
+    # Make all-gather nodes (and related nodes) recomputable, to circumvent
+    # https://github.com/pytorch/pytorch/issues/136433
+    for node in graph.nodes:
+        if is_wait_tensor_from_fsdp(node):
+            ag_node = node.args[0]
+            force_recompute_node(ag_node)  # all_gather
+            force_recompute_node(node)  # wait_tensor
+            # Force-recompute slice that comes after wait
+            for user in node.users:
+                if (
+                    user.op == "call_function"
+                    and user.target == torch.ops.aten.slice.Tensor
+                ):
+                    force_recompute_node(user)
+            # Force-recompute potential dtype casts from all_gather
+            if (
+                ag_node.all_input_nodes[0].op == "call_function"
+                and ag_node.args[0].target
+                == torch.ops.prims.convert_element_type.default
+            ):
+                force_recompute_node(ag_node.all_input_nodes[0])
+
+    return gm
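To see the pass in action, here is a hedged toy demo, assuming the new module lands at `torchtitan/experiments/simple_fsdp/reshard_after_forward.py` (matching the relative import in `backend.py` above). An FX graph is built by hand to mimic what SimpleFSDP traces for one sharded parameter (placeholder, all_gather_into_tensor, wait_tensor, matmul), with dummy group arguments and no actual execution, and is then annotated and inspected.

```python
import torch
import torch.fx as fx

from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
    annotate_fsdp_all_gather,
)

# Build a tiny graph manually instead of tracing a real SimpleFSDP model,
# so no process group or distributed init is needed.
g = fx.Graph()
x = g.placeholder("x")
param_shard = g.placeholder("param_shard")
ag = g.call_function(
    torch.ops._c10d_functional.all_gather_into_tensor.default,
    (param_shard, 8, "0"),  # dummy group_size / group_name
)
w = g.call_function(torch.ops._c10d_functional.wait_tensor.default, (ag,))
out = g.call_function(torch.ops.aten.mm.default, (x, w))
g.output(out)
gm = fx.GraphModule(torch.nn.Module(), g)

annotate_fsdp_all_gather(gm, reshard_after_forward=True)
for node in gm.graph.nodes:
    print(node.name, node.meta.get("recompute"))
# The all_gather and wait_tensor nodes are tagged CheckpointPolicy.MUST_RECOMPUTE;
# with reshard_after_forward=False they would be tagged MUST_SAVE instead.
```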
