Commit 3b6149c

[simplefsdp] fix region ac in zero 2
1 parent ff07852 commit 3b6149c

File tree: 5 files changed (165 additions, 60 deletions)

torchtitan/experiments/simple_fsdp/activation_checkpoint.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from torch.utils.checkpoint import CheckpointPolicy
+
+
+def is_graph_input(node: torch.fx.Node) -> bool:
+    return node.op == "placeholder"
+
+
+def is_wait_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == torch.ops._c10d_functional.wait_tensor.default
+    )
+
+
+def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+    )
+
+
+def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the node is a wait_tensor node that is the result of an all_gather
+    that can be arbitrarily prefetched, i.e., if all its recursive inputs are
+    single-input operators that lead to a graph input.
+    """
+    if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
+        n: torch.fx.Node = node.all_input_nodes[0]
+        while len(n.all_input_nodes) == 1:
+            if is_graph_input(n.all_input_nodes[0]):
+                return True
+            n = n.all_input_nodes[0]
+    return False
+
+
+def annotate_fsdp_all_gather(
+    gm: torch.fx.GraphModule, reshard_after_forward: bool
+) -> torch.fx.GraphModule:
+    """
+    Force recompute of all_gather nodes from simple_fsdp in the graph.
+    This pass should be added in torch._inductor.config.joint_custom_post_pass.
+    """
+    graph = gm.graph
+
+    def force_recompute_node(node):
+        if reshard_after_forward:
+            node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
+        else:
+            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
+        # ac_graph_id is used in the partitioner to decide
+        # whether two nodes which have AC applied come from
+        # different AC regions. This is needed because nodes on
+        # the boundary of two AC regions are marked as MUST_SAVE.
+        # In our case we assign one fixed ac_graph_id so that
+        # all nodes we tag for recomputation do indeed get recomputed
+        # and are not influenced by other nodes in the graph with
+        # nearby ac_graph_id values.
+        node.meta["ac_graph_id"] = 0
+
+    # Make all-gather nodes (and related nodes) recomputable, to circumvent
+    # https://github.com/pytorch/pytorch/issues/136433
+    for node in graph.nodes:
+        if is_wait_tensor_from_fsdp(node):
+            ag_node = node.args[0]
+            force_recompute_node(ag_node)  # all_gather
+            force_recompute_node(node)  # wait_tensor
+            # Force-recompute the slice that comes after the wait
+            for user in node.users:
+                if (
+                    user.op == "call_function"
+                    and user.target == torch.ops.aten.slice.Tensor
+                ):
+                    force_recompute_node(user)
+            # Force-recompute potential dtype casts from all_gather
+            if (
+                ag_node.all_input_nodes[0].op == "call_function"
+                and ag_node.args[0].target
+                == torch.ops.prims.convert_element_type.default
+            ):
+                force_recompute_node(ag_node.all_input_nodes[0])
+
+    return gm
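
For intuition, a minimal sketch of the mechanism this pass relies on: checkpoint policies are attached to FX nodes via node.meta and later read by the partitioner. The toy graph below contains no collective ops, the module name Toy is illustrative only, and every call_function node is tagged rather than only the all_gather/wait_tensor/slice nodes the real pass selects:

import torch
from torch.utils.checkpoint import CheckpointPolicy


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).sum()


gm = torch.fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # Same metadata keys as force_recompute_node above, applied indiscriminately.
        node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
        node.meta["ac_graph_id"] = 0
print([(n.name, n.meta.get("recompute")) for n in gm.graph.nodes])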

torchtitan/experiments/simple_fsdp/backend.py

Lines changed: 27 additions & 6 deletions
@@ -4,20 +4,23 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Union
+from typing import Any
 
 import torch
+import torch._functorch.config as functorch_config
 
+from .activation_checkpoint import annotate_fsdp_all_gather
 
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+
+def get_compile_backend(backend_name: str, reshard_after_forward: bool) -> callable:
     # return the compile backends used in SimpleFSDP training
     # Step1: check if backend_name is inside available torch.compile backends
     # Step2: check if the backend_name has been registered as a customized backend
     available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
-    if backend_name in available_torch_backend:
-        return backend_name
 
-    if backend_name == "aot_eager_autobucketing":
+    if backend_name in available_torch_backend:
+        backend = torch._dynamo.lookup_backend(backend_name)
+    elif backend_name == "aot_eager_autobucketing":
         # Perform auto optimization in aten fx-level and execute code in aot_eager backend
         # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
@@ -46,4 +49,22 @@ def aten_autobucketing_reordering_pass(
     else:
         raise AssertionError(f"Unsupported customized backend: {backend_name}")
 
-    return backend
+    def joint_ac_pass(
+        gm: torch.fx.GraphModule, example_inputs: Any
+    ) -> torch.fx.GraphModule:
+        # This pass implements simple_fsdp's reshard_after_forward behavior.
+        # When reshard_after_forward is set to True, it annotates the simple_fsdp AG
+        # nodes with CheckpointPolicy.MUST_RECOMPUTE.
+        # When reshard_after_forward is set to False, it annotates the simple_fsdp AG
+        # nodes with CheckpointPolicy.MUST_SAVE.
+        gm = annotate_fsdp_all_gather(gm, reshard_after_forward)
+        gm.recompile()
+        return gm
+
+    def simple_fsdp_custom_pass(*args, **kwargs):
+        # The AC pass has to operate on the joint graph, before the partitioner,
+        # for the AC annotation to take effect.
+        with functorch_config.patch("joint_custom_pass", joint_ac_pass):
+            return backend(*args, **kwargs)
+
+    return simple_fsdp_custom_pass
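
A usage sketch of the new signature (assuming torchtitan's simple_fsdp experiment is importable and that "aot_eager" is among the backends returned by torch._dynamo.list_backends(); TinyModel is a stand-in, not a torchtitan model):

import torch
from torchtitan.experiments.simple_fsdp.backend import get_compile_backend


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2


# reshard_after_forward=True tags simple_fsdp all-gathers as MUST_RECOMPUTE in the
# joint graph; False tags them as MUST_SAVE (the zero-2-style behavior this commit fixes).
backend = get_compile_backend("aot_eager", reshard_after_forward=True)
compiled = torch.compile(TinyModel(), backend=backend, fullgraph=True)
out = compiled(torch.randn(4))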

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 28 additions & 19 deletions
@@ -10,16 +10,18 @@
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+
+from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.deepseek_v3.infra.parallelize import (
-    apply_ac,
     apply_moe_ep_tp,
     apply_non_moe_tp,
 )
 from torchtitan.tools.logging import logger
 
-from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
+from ..backend import get_compile_backend
 
+from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
 
 # Adapted from llama4/infra/parallelize.py
 def parallelize_deepseekv3(
@@ -91,20 +93,6 @@ def parallelize_deepseekv3(
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
 
-    match job_config.parallelism.fsdp_reshard_after_forward:
-        case "always":
-            reshard_after_forward = True
-        case "never":
-            reshard_after_forward = False
-        case "default":
-            # For PP, by default do not reshard after forward to avoid per-microbatch
-            # all-gathers, which can be expensive and non-overlapped
-            reshard_after_forward = not parallel_dims.pp_enabled
-        case _:
-            raise ValueError(
-                f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
-            )
-
     # apply data parallel
     dp_mesh: DeviceMesh | None = None
     if (
@@ -157,7 +145,6 @@
            dp_mode,
            ac_mode=job_config.activation_checkpoint.mode,
            mp_policy=mp_policy,
-            reshard_after_forward=reshard_after_forward,
            shard_dim=experts_shard_dim,
            reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
        )
@@ -168,7 +155,6 @@
         dp_mode,
         ac_mode=job_config.activation_checkpoint.mode,
         mp_policy=mp_policy,
-        reshard_after_forward=reshard_after_forward,
     )
 
     logger.info(
@@ -178,6 +164,29 @@
     if job_config.compile.enable:
         torch._inductor.config.reorder_for_peak_memory = False
         torch._dynamo.config.capture_scalar_outputs = True
-        model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
+
+        match job_config.parallelism.fsdp_reshard_after_forward:
+            case "always":
+                reshard_after_forward = True
+            case "never":
+                reshard_after_forward = False
+            case "default":
+                # For PP, by default do not reshard after forward to avoid per-microbatch
+                # all-gathers, which can be expensive and non-overlapped
+                reshard_after_forward = not parallel_dims.pp_enabled
+            case _:
+                raise ValueError(
+                    f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
+                )
+
+        backend = (
+            getattr(job_config.compile, "model_backend_override", None)
+            or job_config.compile.backend
+        )
+        model = torch.compile(
+            model,
+            backend=get_compile_backend(backend, reshard_after_forward),
+            fullgraph=True,
+        )
 
     return model
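
The reshard_after_forward policy resolution above can be captured by a hypothetical standalone helper (not part of this commit), which makes the "always"/"never"/"default" mapping easy to unit-test in isolation:

def resolve_reshard_after_forward(policy: str, pp_enabled: bool) -> bool:
    # Mirrors the match block above: "default" avoids per-microbatch all-gathers under PP.
    match policy:
        case "always":
            return True
        case "never":
            return False
        case "default":
            return not pp_enabled
        case _:
            raise ValueError(f"Invalid reshard_after_forward_policy: {policy}.")


assert resolve_reshard_after_forward("default", pp_enabled=True) is False
assert resolve_reshard_after_forward("always", pp_enabled=True) is True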

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 15 additions & 15 deletions
@@ -112,6 +112,20 @@ def parallelize_llama(
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
 
+    model = data_parallel(
+        model,
+        parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
+        mode=dp_mode,
+        ac_mode=job_config.activation_checkpoint.mode,
+        mp_policy=mp_policy,
+    )
+    logger.info(
+        "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
+    )
+
+    if job_config.compile.enable and "model" in job_config.compile.components:
+        torch._inductor.config.reorder_for_peak_memory = False
+
     match job_config.parallelism.fsdp_reshard_after_forward:
         case "always":
             reshard_after_forward = True
@@ -126,27 +140,13 @@ def parallelize_llama(
                 f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
             )
 
-    model = data_parallel(
-        model,
-        parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
-        mode=dp_mode,
-        ac_mode=job_config.activation_checkpoint.mode,
-        mp_policy=mp_policy,
-        reshard_after_forward=reshard_after_forward,
-    )
-    logger.info(
-        "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
-    )
-
-    if job_config.compile.enable and "model" in job_config.compile.components:
-        torch._inductor.config.reorder_for_peak_memory = False
         backend = (
             getattr(job_config.compile, "model_backend_override", None)
             or job_config.compile.backend
         )
         model = torch.compile(
             model,
-            backend=get_compile_backend(backend),
+            backend=get_compile_backend(backend, reshard_after_forward),
             fullgraph=True,
         )
 
torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 1 addition & 20 deletions
@@ -23,7 +23,6 @@
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
 from torch.utils.checkpoint import (
-    checkpoint,
     CheckpointPolicy,
     create_selective_checkpoint_contexts,
 )
@@ -210,7 +209,6 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
-        reshard_after_forward,
         reduction_divide_factor,
     ):
         super().__init__()
@@ -229,7 +227,6 @@ def __init__(
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.reshard_after_forward = reshard_after_forward
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -292,21 +289,7 @@ def forward(self, x: DTensor) -> torch.Tensor:
         if not _active_parametrization:
             return x
 
-        if (
-            self.regional_ac
-            and self.mode in ("fully_shard", "hybrid_shard")
-            and self.reshard_after_forward
-        ):
-            # apply checkpointing to implement reshard_after_forward
-            output = checkpoint(
-                self.replicate_compute,
-                x,
-                use_reentrant=False,
-                context_fn=fsdp_policy,
-            )
-        else:
-            output = self.replicate_compute(x)
-
+        output = self.replicate_compute(x)
         return output
 
 
@@ -316,7 +299,6 @@ def data_parallel(
    mode: str = "replicate",
    ac_mode: str = "none",
    mp_policy: MixedPrecisionPolicy | None = None,
-    reshard_after_forward: bool = True,
    shard_dim: int = 0,
    reduction_divide_factor: float | None = None,
 ):
@@ -381,7 +363,6 @@ def data_parallel(
                 mode,
                 regional_ac,
                 mp_policy=mp_policy,
-                reshard_after_forward=reshard_after_forward,
                 reduction_divide_factor=reduction_divide_factor,
             ),
         )

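For contrast with the removed eager path above (replicate_compute wrapped in torch.utils.checkpoint.checkpoint with a context_fn), here is a minimal standalone sketch of selective activation checkpointing; the policy function, tensors, and the choice to save matmuls are illustrative and unrelated to torchtitan's fsdp_policy:

import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def policy(ctx, op, *args, **kwargs):
    # Save matmul outputs for backward; let everything else be recomputed.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def fn(x, w):
    return torch.relu(x @ w).sum()


x = torch.randn(8, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = checkpoint(
    fn,
    x,
    w,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy),
)
out.backward()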