
Commit 284695c

committed
[simplefsdp] fix region ac in zero 2
1 parent ff07852 commit 284695c

File tree

3 files changed: +375 -17 lines changed
Lines changed: 353 additions & 0 deletions
@@ -0,0 +1,353 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file provides the util functions to apply activation checkpointing to the model.
# Technically, this is not a part of distributed, but the distributed module is the best place to put it.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.tools.logging import logger, warn_once


_layer_sac_count = 0

# To avoid recomputing the SimpleFSDP all_gather in zero2-style FSDP, enforce an
# additional policy that always marks the SimpleFSDP all_gather (and its wait/copy ops) as PREFER_SAVE.
_op_simple_fsdp_save_list = {
    torch.ops._c10d_functional.all_gather_into_tensor.default,
    torch.ops._c10d_functional.wait_tensor.default,
    torch.ops.aten._to_copy.default,
}

from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def _get_custom_policy():
    def _custom_policy(ctx, func, *args, **kwargs):
        # Always save critical communication and copy ops
        if func in _op_simple_fsdp_save_list:
            return CheckpointPolicy.PREFER_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy


def simplefsdp_checkpointing_context_fn():
    return create_selective_checkpoint_contexts(_get_custom_policy())


def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
    """Apply layer selective activation checkpointing to the module.

    Args:
        module (nn.Module): The module to apply layer selective activation checkpointing to.
        ac_config (ACConfig): The activation checkpointing config.

    Returns:
        nn.Module: The module with layer selective activation checkpointing applied.
    """
    global _layer_sac_count
    _layer_sac_count += 1
    ac_freq = int(ac_config.selective_ac_option)
    if not ac_freq or _layer_sac_count % ac_freq == 0:
        return ptd_checkpoint_wrapper(
            module,
            context_fn=simplefsdp_checkpointing_context_fn,
            preserve_rng_state=ac_config.preserve_rng_state,
            determinism_check=ac_config.determinism_check,
            early_stop=ac_config.early_stop,
            debug=ac_config.debug,
        )
    else:
        return module


def _apply_op_sac(
    module: nn.Module,
    ac_config: ACConfig,
    *,
    base_fqn: str | None = None,
    op_sac_save_list: set[torch._ops.OpOverload],
) -> nn.Module:
    """Apply selective activation checkpointing to the module.

    Args:
        module (nn.Module): The module to apply selective activation checkpointing to.
        ac_config (ACConfig): The activation checkpointing config.
        base_fqn (str, optional): The base fqn of the module. Defaults to None.
        op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
            of recomputing.

    Returns:
        nn.Module: The module with selective activation checkpointing applied.
    """
    from torch.utils.checkpoint import (
        CheckpointPolicy,
        create_selective_checkpoint_contexts,
    )

    mm_recompute_shapes = set()
    if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
        for module_fqn, submod in module.named_modules():
            fqn = module_fqn
            if base_fqn is not None:
                fqn = f"{base_fqn}.{module_fqn}"
            if not any(
                filter_fqn in fqn
                for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
            ):
                continue
            if not isinstance(submod, nn.Linear):
                raise ValueError(
                    "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
                    f"a nn.Linear, but got: {submod}"
                )
            out_f, in_f = submod.weight.shape
            mm_recompute_shapes.add((in_f, out_f))
        logger.debug(
            f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
        )

    def _get_custom_policy(meta):
        def _custom_policy(ctx, func, *args, **kwargs):
            if (
                func == torch.ops.aten._to_copy.default
                and "cuda" in str(args[0].device)
                and "device" in kwargs
                and str(kwargs["device"]) == "cpu"
            ):
                return CheckpointPolicy.MUST_SAVE

            if func in _op_simple_fsdp_save_list:
                return CheckpointPolicy.PREFER_SAVE

            mode = "recompute" if ctx.is_recompute else "forward"
            mm_count_key = f"{mode}_mm_count"
            if func == torch.ops.aten.mm.default:
                if args[1].shape in mm_recompute_shapes:
                    return CheckpointPolicy.PREFER_RECOMPUTE
                meta[mm_count_key] += 1
            # Saves output of all compute ops, except every second mm
            to_save = func in op_sac_save_list and not (
                func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
            )
            return (
                CheckpointPolicy.MUST_SAVE
                if to_save
                else CheckpointPolicy.PREFER_RECOMPUTE
            )

        return _custom_policy

    def selective_checkpointing_context_fn():
        meta = defaultdict(int)
        return create_selective_checkpoint_contexts(_get_custom_policy(meta))

    return ptd_checkpoint_wrapper(
        module,
        context_fn=selective_checkpointing_context_fn,
        preserve_rng_state=ac_config.preserve_rng_state,
        determinism_check=ac_config.determinism_check,
        early_stop=ac_config.early_stop,
        debug=ac_config.debug,
    )


def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
    """Apply full activation checkpointing to the module.

    Args:
        module (nn.Module): The module to apply full activation checkpointing to.
        ac_config (ACConfig): The activation checkpointing config.

    Returns:
        nn.Module: The module with full activation checkpointing applied.
    """
    return ptd_checkpoint_wrapper(
        module,
        context_fn=simplefsdp_checkpointing_context_fn,
        preserve_rng_state=ac_config.preserve_rng_state,
        determinism_check=ac_config.determinism_check,
        early_stop=ac_config.early_stop,
        debug=ac_config.debug,
    )


def _apply_op_sac_to_transformer_block_with_flex(
    module: nn.Module,
    ac_config: ACConfig,
    *,
    base_fqn: str | None = None,
    model_compile_enabled: bool = False,
    op_sac_save_list: set[torch._ops.OpOverload],
) -> nn.Module:
    """Apply SAC to the transformer block that uses FlexAttention.

    Args:
        module (nn.Module): The transformer block to apply SAC to.
        ac_config (ACConfig): The Activation Checkpoint config.
        base_fqn (str, optional): The base fqn of the module. Defaults to None.
        model_compile_enabled (bool): Whether model compilation is enabled.
            Defaults to False.
        op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
            of recomputing.

    Returns:
        nn.Module: The transformer block with SAC applied.
    """

    warn_once(
        logger,
        (
            "Flex Attention requires compilation for good performance.\n"
            "Thus, torch.compile is always used for Flex Attention, "
            "regardless of the compile.enable flag.\n"
            "However, when selective activation checkpointing (SAC) is enabled, "
            "torch.compile may be invalidated:\n"
            "1. If compile.enable is False, SAC will ignore any torch.compile "
            "inside the SAC region.\n"
            "2. If compile.enable is True but the transformer block contains an MoE module.\n\n"
            "For both cases, we will not wrap the entire TransformerBlock with SAC:\n"
            " - For case 1: SAC will be used for MoE and FeedForward modules, "
            "while full AC will be used for the Attention module.\n"
            " - For case 2: SAC will be applied to MoE and Attention modules if the block "
            "is sparse. But we still apply SAC to an entire dense block.\n"
        ),
    )

    def wrap_submodule(name: str, full_ac: bool = False) -> None:
        submodule = getattr(module, name)
        if full_ac:
            submodule = _apply_full_ac(submodule, ac_config)
        else:
            submodule = _apply_op_sac(
                submodule,
                ac_config,
                base_fqn=f"{base_fqn}.{name}" if base_fqn else name,
                op_sac_save_list=op_sac_save_list,
            )
        module.register_module(name, submodule)

    if hasattr(module, "moe"):
        wrap_submodule("moe", full_ac=False)
        if model_compile_enabled:
            wrap_submodule("attention", full_ac=False)
        else:
            wrap_submodule("attention", full_ac=True)
    else:
        if model_compile_enabled:
            module = _apply_op_sac(
                module,
                ac_config,
                base_fqn=base_fqn,
                op_sac_save_list=op_sac_save_list,
            )
        else:
            wrap_submodule("feed_forward", full_ac=False)
            wrap_submodule("attention", full_ac=True)
    return module


# Dispatch to full AC, op-level SAC, or layer-frequency SAC based on ac_config.
def _apply_ac_to_transformer_block(
    module: nn.Module,
    ac_config: ACConfig,
    *,
    base_fqn: str | None = None,
    model_compile_enabled: bool = False,
    use_flex_attn: bool = False,
    op_sac_save_list: set[torch._ops.OpOverload] | None = None,
) -> nn.Module:
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return _apply_full_ac(module, ac_config)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )

    if use_op_sac:
        op_sac_save_list = op_sac_save_list or set()
        if use_flex_attn:
            """
            For Flex Attention, we need to apply SAC carefully to avoid invalidating
            torch.compile. Any torch.compile inside the SAC region will be ignored,
            and any torch.compile outside the SAC region will also be ignored if the
            SAC region contains a graph break (e.g., MoE).

            TODO: remove this once SAC issues are resolved.
            """
            return _apply_op_sac_to_transformer_block_with_flex(
                module,
                ac_config,
                base_fqn=base_fqn,
                model_compile_enabled=model_compile_enabled,
                op_sac_save_list=op_sac_save_list,
            )
        else:
            return _apply_op_sac(
                module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
            )

    return _apply_layer_sac(module, ac_config)


def apply_ac(
    model: nn.Module,
    ac_config: ACConfig,
    *,
    model_compile_enabled: bool = False,
    use_flex_attn: bool = False,
    op_sac_save_list: set[torch._ops.OpOverload] | None = None,
    base_folder: str = "",
) -> None:
    """Apply activation checkpointing to the model.

    Note that SAC, Flex Attention and model compilation have some conflicts.
    We explicitly ask the user to pass these configs to warn as the wrapping
    will be different.

    Args:
        model (nn.Module): The model to apply activation checkpointing to.
        ac_config (ACConfig): The activation checkpointing config.
        model_compile_enabled (bool): Whether torch.compile is enabled for the model.
        use_flex_attn (bool): Whether flex attention is enabled for the model.
        op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
            of recomputing.
    Returns:
        None
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = _apply_ac_to_transformer_block(
            transformer_block,
            ac_config,
            base_fqn=f"layers.{layer_id}",
            model_compile_enabled=model_compile_enabled,
            use_flex_attn=use_flex_attn,
            op_sac_save_list=op_sac_save_list,
        )
        model.layers.register_module(layer_id, transformer_block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
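To illustrate the policy this commit introduces, below is a minimal, self-contained sketch of how a PREFER_SAVE policy plugs into PyTorch's selective activation checkpointing via create_selective_checkpoint_contexts. The toy forward function and the use of torch.ops.aten.mm.default as the always-save op are stand-ins for the SimpleFSDP all_gather/wait ops; this snippet is not part of the commit.

# Minimal sketch: a custom SAC policy that prefers saving the outputs of a chosen
# set of ops (here torch.mm stands in for the all_gather/wait_tensor ops kept by
# _op_simple_fsdp_save_list) while everything else is recomputed in backward.
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

_save_list = {torch.ops.aten.mm.default}  # stand-in for the SimpleFSDP save list


def _policy(ctx, func, *args, **kwargs):
    if func in _save_list:
        return CheckpointPolicy.PREFER_SAVE  # keep this output for backward
    return CheckpointPolicy.PREFER_RECOMPUTE  # recompute everything else


# Same shape as simplefsdp_checkpointing_context_fn: a callable returning the contexts.
context_fn = partial(create_selective_checkpoint_contexts, _policy)


def fwd(x, w):
    # mm output is saved under the policy; the relu and sum are recomputed in backward.
    return torch.relu(torch.mm(x, w)).sum()


x = torch.randn(8, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = checkpoint(fwd, x, w, use_reentrant=False, context_fn=context_fn)
out.backward()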

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 2 additions & 2 deletions
@@ -10,17 +10,17 @@
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+
+from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.deepseek_v3.infra.parallelize import (
-    apply_ac,
     apply_moe_ep_tp,
     apply_non_moe_tp,
 )
 from torchtitan.tools.logging import logger
 
 from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
 
-
 # Adapted from llama4/infra/parallelize.py
 def parallelize_deepseekv3(
     model: nn.Module,
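The call site of apply_ac inside parallelize_deepseekv3 falls outside this hunk. As a hedged sketch of the typical wiring, the relocated helper would be invoked roughly as below; the wrapper name and the JobConfig attribute paths (activation_checkpoint, compile.enable) are assumptions based on torchtitan conventions, not taken from this commit.

# Hedged sketch, not part of this diff: how the relocated apply_ac is typically
# wired up inside a parallelize function. Attribute paths below are assumed.
import torch.nn as nn

from torchtitan.distributed.activation_checkpoint import apply_ac


def maybe_apply_ac(model: nn.Module, job_config, use_flex_attn: bool) -> None:
    apply_ac(
        model,
        job_config.activation_checkpoint,  # assumed field name
        model_compile_enabled=job_config.compile.enable,  # assumed field path
        use_flex_attn=use_flex_attn,
        op_sac_save_list=None,  # None falls back to an empty save list internally
    )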
