
Commit 429bbb7

[simplefsdp] fix region ac in zero 2
1 parent ff07852 commit 429bbb7

3 files changed: +349 -3 lines changed
torchtitan/experiments/simple_fsdp/activation_checkpoint.py

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file provides the util functions to apply activation checkpointing to the model.
# Technically, this is not a part of distributed, but the distributed module is the best place to put it.

import os
from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.tools.logging import logger, warn_once


_layer_sac_count = 0

# Ops saved (not recomputed) to avoid re-issuing the SimpleFSDP all-gather (AG) in ZeRO-2
_op_simple_fsdp_save_list = {
    torch.ops._c10d_functional.all_gather_into_tensor.default,
    torch.ops._c10d_functional.wait_tensor.default,
    torch.ops.aten._to_copy.default,
}


def _get_custom_policy():
    def _custom_policy(ctx, func, *args, **kwargs):
        # Always save critical communication and copy ops
        if func in _op_simple_fsdp_save_list:
            return CheckpointPolicy.PREFER_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy


def simplefsdp_checkpointing_context_fn():
    return create_selective_checkpoint_contexts(_get_custom_policy())


def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
    """Apply layer selective activation checkpointing to the module.

    Args:
        module (nn.Module): The module to apply layer selective activation checkpointing to.
        ac_config (ACConfig): The activation checkpointing config.

    Returns:
        nn.Module: The module with layer selective activation checkpointing applied.
    """
    global _layer_sac_count
    _layer_sac_count += 1
    ac_freq = int(ac_config.selective_ac_option)
    if not ac_freq or _layer_sac_count % ac_freq == 0:
        return ptd_checkpoint_wrapper(
            module,
            context_fn=simplefsdp_checkpointing_context_fn,
            preserve_rng_state=ac_config.preserve_rng_state,
            determinism_check=ac_config.determinism_check,
            early_stop=ac_config.early_stop,
            debug=ac_config.debug,
        )
    else:
        return module


def _apply_op_sac(
    module: nn.Module,
    ac_config: ACConfig,
    *,
    base_fqn: str | None = None,
    op_sac_save_list: set[torch._ops.OpOverload],
) -> nn.Module:
    """Apply selective activation checkpointing to the module.

    Args:
        module (nn.Module): The module to apply selective activation checkpointing to.
        ac_config (ACConfig): The activation checkpointing config.
        base_fqn (str, optional): The base fqn of the module. Defaults to None.
        op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
            of recomputing.

    Returns:
        nn.Module: The module with selective activation checkpointing applied.
    """
    mm_recompute_shapes = set()
    if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
        for module_fqn, submod in module.named_modules():
            fqn = module_fqn
            if base_fqn is not None:
                fqn = f"{base_fqn}.{module_fqn}"
            if not any(
                filter_fqn in fqn
                for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
            ):
                continue
            if not isinstance(submod, nn.Linear):
                raise ValueError(
                    "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
                    f"a nn.Linear, but got: {submod}"
                )
            out_f, in_f = submod.weight.shape
            mm_recompute_shapes.add((in_f, out_f))
        logger.debug(
            f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
        )

    def _get_custom_policy(meta):
        def _custom_policy(ctx, func, *args, **kwargs):
            if (
                func == torch.ops.aten._to_copy.default
                and "cuda" in str(args[0].device)
                and "device" in kwargs
                and str(kwargs["device"]) == "cpu"
            ):
                return CheckpointPolicy.MUST_SAVE

            if func in _op_simple_fsdp_save_list:
                return CheckpointPolicy.PREFER_SAVE

            mode = "recompute" if ctx.is_recompute else "forward"
            mm_count_key = f"{mode}_mm_count"
            if func == torch.ops.aten.mm.default:
                if args[1].shape in mm_recompute_shapes:
                    return CheckpointPolicy.PREFER_RECOMPUTE
                meta[mm_count_key] += 1
            # Save the output of all compute ops, except every second mm
            to_save = func in op_sac_save_list and not (
                func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
            )
            return (
                CheckpointPolicy.MUST_SAVE
                if to_save
                else CheckpointPolicy.PREFER_RECOMPUTE
            )

        return _custom_policy

    def selective_checkpointing_context_fn():
        meta = defaultdict(int)
        return create_selective_checkpoint_contexts(_get_custom_policy(meta))

    return ptd_checkpoint_wrapper(
        module,
        context_fn=selective_checkpointing_context_fn,
        preserve_rng_state=ac_config.preserve_rng_state,
        determinism_check=ac_config.determinism_check,
        early_stop=ac_config.early_stop,
        debug=ac_config.debug,
    )


def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
    """Apply full activation checkpointing to the module.

    Args:
        module (nn.Module): The module to apply full activation checkpointing to.
        ac_config (ACConfig): The activation checkpointing config.

    Returns:
        nn.Module: The module with full activation checkpointing applied.
    """
    return ptd_checkpoint_wrapper(
        module,
        context_fn=simplefsdp_checkpointing_context_fn,
        preserve_rng_state=ac_config.preserve_rng_state,
        determinism_check=ac_config.determinism_check,
        early_stop=ac_config.early_stop,
        debug=ac_config.debug,
    )


def _apply_op_sac_to_transformer_block_with_flex(
    module: nn.Module,
    ac_config: ACConfig,
    *,
    base_fqn: str | None = None,
    model_compile_enabled: bool = False,
    op_sac_save_list: set[torch._ops.OpOverload],
) -> nn.Module:
    """Apply SAC to a transformer block that uses FlexAttention.

    Args:
        module (nn.Module): The transformer block to apply SAC to.
        ac_config (ACConfig): The Activation Checkpoint config.
        base_fqn (str, optional): The base fqn of the module. Defaults to None.
        model_compile_enabled (bool): Whether model compilation is enabled.
            Defaults to False.
        op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
            of recomputing.

    Returns:
        nn.Module: The transformer block with SAC applied.
    """

    warn_once(
        logger,
        (
            "Flex Attention requires compilation for good performance.\n"
            "Thus, torch.compile is always used for Flex Attention, "
            "regardless of the compile.enable flag.\n"
            "However, when selective activation checkpointing (SAC) is enabled, "
            "torch.compile may be invalidated:\n"
            "1. If compile.enable is False, SAC will ignore any torch.compile "
            "inside the SAC region.\n"
            "2. If compile.enable is True but the transformer block contains an MoE module.\n\n"
            "For both cases, we will not wrap the entire TransformerBlock with SAC:\n"
            " - For case 1: SAC will be used for the MoE and FeedForward modules, "
            "while full AC will be used for the Attention module.\n"
            " - For case 2: SAC will be applied to the MoE and Attention modules if the block "
            "is sparse, but we still apply SAC to an entire dense block.\n"
        ),
    )

    def wrap_submodule(name: str, full_ac: bool = False) -> None:
        submodule = getattr(module, name)
        if full_ac:
            submodule = _apply_full_ac(submodule, ac_config)
        else:
            submodule = _apply_op_sac(
                submodule,
                ac_config,
                base_fqn=f"{base_fqn}.{name}" if base_fqn else name,
                op_sac_save_list=op_sac_save_list,
            )
        module.register_module(name, submodule)

    if hasattr(module, "moe"):
        wrap_submodule("moe", full_ac=False)
        if model_compile_enabled:
            wrap_submodule("attention", full_ac=False)
        else:
            wrap_submodule("attention", full_ac=True)
    else:
        if model_compile_enabled:
            module = _apply_op_sac(
                module,
                ac_config,
                base_fqn=base_fqn,
                op_sac_save_list=op_sac_save_list,
            )
        else:
            wrap_submodule("feed_forward", full_ac=False)
            wrap_submodule("attention", full_ac=True)
    return module


def _apply_ac_to_transformer_block(
    module: nn.Module,
    ac_config: ACConfig,
    *,
    base_fqn: str | None = None,
    model_compile_enabled: bool = False,
    use_flex_attn: bool = False,
    op_sac_save_list: set[torch._ops.OpOverload] | None = None,
) -> nn.Module:
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return _apply_full_ac(module, ac_config)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )

    if use_op_sac:
        op_sac_save_list = op_sac_save_list or set()
        if use_flex_attn:
            """
            For Flex Attention, we need to apply SAC carefully to avoid invalidating
            torch.compile. Any torch.compile inside the SAC region will be ignored,
            and any torch.compile outside the SAC region will also be ignored if the
            SAC region contains a graph break (e.g., MoE).

            TODO: remove this once SAC issues are resolved.
            """
            return _apply_op_sac_to_transformer_block_with_flex(
                module,
                ac_config,
                base_fqn=base_fqn,
                model_compile_enabled=model_compile_enabled,
                op_sac_save_list=op_sac_save_list,
            )
        else:
            return _apply_op_sac(
                module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
            )

    return _apply_layer_sac(module, ac_config)


def apply_ac(
    model: nn.Module,
    ac_config: ACConfig,
    *,
    model_compile_enabled: bool = False,
    use_flex_attn: bool = False,
    op_sac_save_list: set[torch._ops.OpOverload] | None = None,
    base_folder: str = "",
) -> None:
    """Apply activation checkpointing to the model.

    Note that SAC, Flex Attention, and model compilation have some conflicts.
    We explicitly ask the user to pass these configs so that we can warn when
    the wrapping will be different.

    Args:
        model (nn.Module): The model to apply activation checkpointing to.
        ac_config (ACConfig): The activation checkpointing config.
        model_compile_enabled (bool): Whether torch.compile is enabled for the model.
        use_flex_attn (bool): Whether flex attention is enabled for the model.
        op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
            of recomputing.

    Returns:
        None
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = _apply_ac_to_transformer_block(
            transformer_block,
            ac_config,
            base_fqn=f"layers.{layer_id}",
            model_compile_enabled=model_compile_enabled,
            use_flex_attn=use_flex_attn,
            op_sac_save_list=op_sac_save_list,
        )
        model.layers.register_module(layer_id, transformer_block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 1 addition & 2 deletions
@@ -12,14 +12,13 @@
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.deepseek_v3.infra.parallelize import (
-    apply_ac,
     apply_moe_ep_tp,
     apply_non_moe_tp,
 )
 from torchtitan.tools.logging import logger
 
 from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
-
+from ..activation_checkpoint import apply_ac
 
 # Adapted from llama4/infra/parallelize.py
 def parallelize_deepseekv3(

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -9,14 +9,14 @@
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
-from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.llama3.infra.parallelize import apply_tp
 from torchtitan.tools.logging import logger
 
 from ..backend import get_compile_backend
 
 from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
+from ..activation_checkpoint import apply_ac
 
 
 # for selective op activation checkpointing
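
With the module relocated, both parallelize entry points now import apply_ac from the SimpleFSDP package instead of torchtitan.distributed. Below is a minimal sketch of what a call site could look like; the helper name, the job_config field names (activation_checkpoint, compile.enable), and the empty op-save list are assumptions for illustration, not taken from this diff.

from torch import nn

from torchtitan.config import JobConfig
from torchtitan.experiments.simple_fsdp.activation_checkpoint import apply_ac


# Hypothetical helper mirroring how the two parallelize.py files would invoke
# the relocated apply_ac; config field names are illustrative assumptions.
def checkpoint_model(model: nn.Module, job_config: JobConfig, use_flex_attn: bool) -> None:
    apply_ac(
        model,
        job_config.activation_checkpoint,
        model_compile_enabled=job_config.compile.enable,
        use_flex_attn=use_flex_attn,
        op_sac_save_list=set(),  # left empty here purely for illustration
    )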
