
Commit 8fed32a

committed
misc
1 parent 1b3d37d commit 8fed32a

File tree

1 file changed: +66 -53 lines changed


torchtitan/distributed/activation_checkpoint.py

Lines changed: 66 additions & 53 deletions
@@ -22,9 +22,7 @@
 _layer_sac_count = 0
 
 
-def _apply_layer_sac(
-    module: nn.Module, ac_config: ACConfig, *, ac_freq: int | None = None
-) -> nn.Module:
+def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
     """Apply layer selective activation checkpointing to the module.
 
     Args:
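
This signature change drops the explicit ac_freq keyword; per-layer SAC now only needs the config. A minimal sketch of how wrapping every Nth block typically works, reusing the _layer_sac_count counter shown above and assuming ac_config.selective_ac_option holds the frequency as a digit string (per the validation further down in this file). This is illustrative only, not the exact body of _apply_layer_sac:

# Illustrative sketch, not torchtitan's exact implementation.
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

def _layer_sac_sketch(module, ac_config):
    global _layer_sac_count
    _layer_sac_count += 1
    ac_freq = int(ac_config.selective_ac_option)  # digit string, e.g. "2"
    if ac_freq <= 1 or _layer_sac_count % ac_freq == 0:
        # Checkpoint this block: drop its activations and recompute them in backward.
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
    return module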
@@ -58,12 +56,11 @@ def _apply_op_sac(
         module (nn.Module): The module to apply selective activation checkpointing to.
         ac_config (ActivationCheckpoint): The activation checkpointing config.
         base_fqn (str, optional): The base fqn of the module. Defaults to None.
-        save_list (set[torch._ops.OpOverload]): The list of ops to save when selective
-            activation checkpointing is used.
+        save_list (set[torch._ops.OpOverload]): The list of ops to save instead
+            of recomputing.
 
     Returns:
         nn.Module: The module with selective activation checkpointing applied.
-
     """
     from torch.utils.checkpoint import (
         CheckpointPolicy,
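
The reworded save_list description ("ops to save instead of recomputing") matches the op-level SAC mechanism that the CheckpointPolicy import above points at: a policy function marks the listed ops as MUST_SAVE and lets everything else be recomputed in backward. A hedged, self-contained illustration; the op set and helper names are examples, not torchtitan's actual list:

import torch
from functools import partial
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Example save list: keep matmul results; recompute everything else in backward.
example_save_list = {torch.ops.aten.mm.default}

def _policy_fn(ctx, op, *args, **kwargs):
    if op in example_save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def run_block_with_op_sac(block, x):
    # context_fn must build fresh context managers on every call.
    context_fn = partial(create_selective_checkpoint_contexts, _policy_fn)
    return checkpoint(block, x, use_reentrant=False, context_fn=context_fn)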
@@ -130,59 +127,29 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
     )
 
 
-def _apply_ac_to_transformer_block(
+def _apply_op_sac_to_transformer_block_with_flex(
     module: nn.Module,
     ac_config: ACConfig,
     *,
     base_fqn: str | None = None,
     model_compile_enabled: bool = False,
-    use_flex_attn: bool = False,
-    save_list: set[torch._ops.OpOverload] | None = None,
+    save_list: set[torch._ops.OpOverload],
 ) -> nn.Module:
-    valid_ac_modes = ("full", "selective")
-    if ac_config.mode not in valid_ac_modes:
-        raise ValueError(
-            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
-        )
+    """Apply SAC to the transformer block that uses FlexAttention.
 
-    if ac_config.mode == "full":
-        return _apply_full_ac(module, ac_config)
-
-    assert ac_config.mode == "selective", f"{ac_config.mode}"
-    use_op_sac = ac_config.selective_ac_option == "op"
-    use_layer_sac = ac_config.selective_ac_option.isdigit()
-    if not use_op_sac and not use_layer_sac:
-        raise ValueError(
-            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
-            f"Valid options: 'op' or a positive int representing layer frequency"
-        )
-
-    if use_op_sac:
-        save_list = save_list or set()
-        if use_flex_attn:
-            return _apply_op_sac_to_transformer_block_with_flex(
-                module,
-                ac_config,
-                base_fqn=base_fqn,
-                model_compile_enabled=model_compile_enabled,
-                save_list=save_list,
-            )
-        else:
-            return _apply_op_sac(
-                module, ac_config, base_fqn=base_fqn, save_list=save_list
-            )
-
-    return _apply_layer_sac(module, ac_config)
+    Args:
+        module (nn.Module): The transformer block to apply SAC to.
+        ac_config (ACConfig): The activation checkpointing config.
+        base_fqn (str, optional): The base fqn of the module. Defaults to None.
+        model_compile_enabled (bool): Whether model compilation is enabled.
+            Defaults to False.
+        save_list (set[torch._ops.OpOverload]): The list of ops to save instead
+            of recomputing.
 
+    Returns:
+        nn.Module: The transformer block with SAC applied.
+    """
 
-def _apply_op_sac_to_transformer_block_with_flex(
-    module: nn.Module,
-    ac_config: ACConfig,
-    *,
-    base_fqn: str | None = None,
-    model_compile_enabled: bool = False,
-    save_list: set[torch._ops.OpOverload],
-) -> nn.Module:
     warn_once(
         logger,
         (
@@ -227,6 +194,51 @@ def _apply_op_sac_to_transformer_block_with_flex(
     return module
 
 
+def _apply_ac_to_transformer_block(
+    module: nn.Module,
+    ac_config: ACConfig,
+    *,
+    base_fqn: str | None = None,
+    model_compile_enabled: bool = False,
+    use_flex_attn: bool = False,
+    save_list: set[torch._ops.OpOverload] | None = None,
+) -> nn.Module:
+    valid_ac_modes = ("full", "selective")
+    if ac_config.mode not in valid_ac_modes:
+        raise ValueError(
+            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
+        )
+
+    if ac_config.mode == "full":
+        return _apply_full_ac(module, ac_config)
+
+    assert ac_config.mode == "selective", f"{ac_config.mode}"
+    use_op_sac = ac_config.selective_ac_option == "op"
+    use_layer_sac = ac_config.selective_ac_option.isdigit()
+    if not use_op_sac and not use_layer_sac:
+        raise ValueError(
+            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
+            f"Valid options: 'op' or a positive int representing layer frequency"
+        )
+
+    if use_op_sac:
+        save_list = save_list or set()
+        if use_flex_attn:
+            return _apply_op_sac_to_transformer_block_with_flex(
+                module,
+                ac_config,
+                base_fqn=base_fqn,
+                model_compile_enabled=model_compile_enabled,
+                save_list=save_list,
+            )
+        else:
+            return _apply_op_sac(
+                module, ac_config, base_fqn=base_fqn, save_list=save_list
+            )
+
+    return _apply_layer_sac(module, ac_config)
+
+
 def apply_ac(
     model: nn.Module,
     ac_config: ACConfig,
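
This hunk relocates _apply_ac_to_transformer_block below the flex variant without changing its dispatch logic: full AC, op-level SAC (with a FlexAttention-specific path), or layer-frequency SAC. A rough sketch of how apply_ac typically drives it over a model's blocks; the model.layers attribute and the register_module swap are assumptions for illustration, not the exact body of apply_ac:

def apply_ac_sketch(model, ac_config, *, model_compile_enabled=False,
                    use_flex_attn=False, save_list=None):
    # Replace each transformer block with its AC-wrapped counterpart.
    for layer_id, block in model.layers.named_children():
        wrapped = _apply_ac_to_transformer_block(
            block,
            ac_config,
            base_fqn=f"layers.{layer_id}",
            model_compile_enabled=model_compile_enabled,
            use_flex_attn=use_flex_attn,
            save_list=save_list,
        )
        model.layers.register_module(layer_id, wrapped)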
@@ -238,15 +250,16 @@ def apply_ac(
     """Apply activation checkpointing to the model.
 
     Note that SAC, Flex Attention and model compilation have some conflicts.
-    We explicitly ask the user to pass these configs to warn if there are conflicts.
+    We explicitly ask the user to pass these configs to warn as the wrapping
+    will be different.
 
     Args:
         model (nn.Module): The model to apply activation checkpointing to.
         ac_config (ActivationCheckpoint): The activation checkpointing config.
         model_compile_enabled (bool): Whether torch.compile is enabled for the model.
         use_flex_attn (bool): Whether flex attention is enabled for the model.
-        save_list (set[torch._ops.OpOverload]): The list of ops to save when selective
-            activation checkpointing is used.
+        save_list (set[torch._ops.OpOverload]): The list of ops to save instead
+            of recomputing.
     Returns:
         None
     """
