_layer_sac_count = 0


-def _apply_layer_sac(
-    module: nn.Module, ac_config: ACConfig, *, ac_freq: int | None = None
-) -> nn.Module:
+def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
    """Apply layer selective activation checkpointing to the module.

    Args:
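
Background on the layer-SAC path (not part of the diff): layer selective AC fully checkpoints every Nth transformer block rather than selecting individual ops, which is why the module-level `_layer_sac_count` counter exists. Below is a minimal sketch of that idea, assuming the frequency is now read from `ac_config.selective_ac_option` instead of the removed `ac_freq` parameter; `checkpoint_wrapper` is PyTorch's activation-checkpoint wrapper, but the exact call torchtitan makes may differ.

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

_layer_sac_count = 0


def layer_sac_sketch(module: nn.Module, ac_freq: int) -> nn.Module:
    """Fully checkpoint every `ac_freq`-th block; leave the others untouched."""
    global _layer_sac_count
    _layer_sac_count += 1
    if ac_freq > 0 and _layer_sac_count % ac_freq == 0:
        # Keep only the block's inputs; recompute its activations during backward.
        return checkpoint_wrapper(
            module,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            preserve_rng_state=False,
        )
    return module
```

With `selective_ac_option = "2"`, for example, every other block would be checkpointed.
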
@@ -58,12 +56,11 @@ def _apply_op_sac(
        module (nn.Module): The module to apply selective activation checkpointing to.
        ac_config (ActivationCheckpoint): The activation checkpointing config.
        base_fqn (str, optional): The base fqn of the module. Defaults to None.
-        save_list (set[torch._ops.OpOverload]): The list of ops to save when selective
-            activation checkpointing is used.
+        save_list (set[torch._ops.OpOverload]): The list of ops to save instead
+            of recomputing.

    Returns:
        nn.Module: The module with selective activation checkpointing applied.
-
    """
    from torch.utils.checkpoint import (
        CheckpointPolicy,
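
For readers unfamiliar with op-level SAC: the `save_list` feeds a policy function that tells the checkpoint machinery which op outputs to keep and which to recompute, via the `CheckpointPolicy` import shown above. A self-contained sketch of that mechanism; the ops placed in `save_list` here are illustrative, not torchtitan's defaults.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative save_list: keep matmul and SDPA outputs, recompute everything else.
save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
}


def _policy(ctx, op, *args, **kwargs):
    # Ops in save_list are stored during forward; all others are recomputed in backward.
    return (
        CheckpointPolicy.MUST_SAVE
        if op in save_list
        else CheckpointPolicy.PREFER_RECOMPUTE
    )


def run_block_with_op_sac(block: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # context_fn is called on each forward to build fresh policy contexts.
    context_fn = partial(create_selective_checkpoint_contexts, _policy)
    return checkpoint(block, x, use_reentrant=False, context_fn=context_fn)


if __name__ == "__main__":
    block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
    out = run_block_with_op_sac(block, torch.randn(4, 16, requires_grad=True))
    out.sum().backward()
```
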
@@ -130,59 +127,29 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
    )


-def _apply_ac_to_transformer_block(
+def _apply_op_sac_to_transformer_block_with_flex(
    module: nn.Module,
    ac_config: ACConfig,
    *,
    base_fqn: str | None = None,
    model_compile_enabled: bool = False,
-    use_flex_attn: bool = False,
-    save_list: set[torch._ops.OpOverload] | None = None,
+    save_list: set[torch._ops.OpOverload],
) -> nn.Module:
-    valid_ac_modes = ("full", "selective")
-    if ac_config.mode not in valid_ac_modes:
-        raise ValueError(
-            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
-        )
+    """Apply SAC to the transformer block that uses FlexAttention.

-    if ac_config.mode == "full":
-        return _apply_full_ac(module, ac_config)
-
-    assert ac_config.mode == "selective", f"{ac_config.mode}"
-    use_op_sac = ac_config.selective_ac_option == "op"
-    use_layer_sac = ac_config.selective_ac_option.isdigit()
-    if not use_op_sac and not use_layer_sac:
-        raise ValueError(
-            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
-            f"Valid options: 'op' or a positive int representing layer frequency"
-        )
-
-    if use_op_sac:
-        save_list = save_list or set()
-        if use_flex_attn:
-            return _apply_op_sac_to_transformer_block_with_flex(
-                module,
-                ac_config,
-                base_fqn=base_fqn,
-                model_compile_enabled=model_compile_enabled,
-                save_list=save_list,
-            )
-        else:
-            return _apply_op_sac(
-                module, ac_config, base_fqn=base_fqn, save_list=save_list
-            )
-
-    return _apply_layer_sac(module, ac_config)
+    Args:
+        module (nn.Module): The transformer block to apply SAC to.
+        ac_config (ACConfig): The activation checkpointing config.
+        base_fqn (str, optional): The base fqn of the module. Defaults to None.
+        model_compile_enabled (bool): Whether model compilation is enabled.
+            Defaults to False.
+        save_list (set[torch._ops.OpOverload]): The list of ops to save instead
+            of recomputing.

+    Returns:
+        nn.Module: The transformer block with SAC applied.
+    """

-def _apply_op_sac_to_transformer_block_with_flex(
-    module: nn.Module,
-    ac_config: ACConfig,
-    *,
-    base_fqn: str | None = None,
-    model_compile_enabled: bool = False,
-    save_list: set[torch._ops.OpOverload],
-) -> nn.Module:
    warn_once(
        logger,
        (
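
A hypothetical call site for the renamed helper, matching the keyword-only signature above. `blocks`, `ac_config`, and `compile_enabled` stand in for whatever the caller has on hand, and both the `base_fqn` format and the `save_list` contents are illustrative.

```python
import torch

# Assumed context: `blocks` maps layer names to transformer blocks, `ac_config`
# has mode="selective" and selective_ac_option="op", and FlexAttention is in use.
for name, block in blocks.items():
    blocks[name] = _apply_op_sac_to_transformer_block_with_flex(
        block,
        ac_config,
        base_fqn=f"layers.{name}",
        model_compile_enabled=compile_enabled,
        save_list={torch.ops.aten.mm.default},  # illustrative
    )
```
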
@@ -227,6 +194,51 @@ def _apply_op_sac_to_transformer_block_with_flex(
    return module


+def _apply_ac_to_transformer_block(
+    module: nn.Module,
+    ac_config: ACConfig,
+    *,
+    base_fqn: str | None = None,
+    model_compile_enabled: bool = False,
+    use_flex_attn: bool = False,
+    save_list: set[torch._ops.OpOverload] | None = None,
+) -> nn.Module:
+    valid_ac_modes = ("full", "selective")
+    if ac_config.mode not in valid_ac_modes:
+        raise ValueError(
+            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
+        )
+
+    if ac_config.mode == "full":
+        return _apply_full_ac(module, ac_config)
+
+    assert ac_config.mode == "selective", f"{ac_config.mode}"
+    use_op_sac = ac_config.selective_ac_option == "op"
+    use_layer_sac = ac_config.selective_ac_option.isdigit()
+    if not use_op_sac and not use_layer_sac:
+        raise ValueError(
+            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
+            f"Valid options: 'op' or a positive int representing layer frequency"
+        )
+
+    if use_op_sac:
+        save_list = save_list or set()
+        if use_flex_attn:
+            return _apply_op_sac_to_transformer_block_with_flex(
+                module,
+                ac_config,
+                base_fqn=base_fqn,
+                model_compile_enabled=model_compile_enabled,
+                save_list=save_list,
+            )
+        else:
+            return _apply_op_sac(
+                module, ac_config, base_fqn=base_fqn, save_list=save_list
+            )
+
+    return _apply_layer_sac(module, ac_config)
+
+
def apply_ac(
    model: nn.Module,
    ac_config: ACConfig,
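
To make the re-homed dispatch easier to follow, here is a small, self-contained sketch that maps config values to the branch `_apply_ac_to_transformer_block` takes. `_SketchACConfig` is a stand-in for torchtitan's `ACConfig`, reduced to the two fields the dispatch reads.

```python
from dataclasses import dataclass


@dataclass
class _SketchACConfig:
    mode: str                  # "full" or "selective"
    selective_ac_option: str   # "op", or an int-like string such as "2"


def describe_ac_dispatch(ac_config: _SketchACConfig, use_flex_attn: bool) -> str:
    # Mirrors the branch order in _apply_ac_to_transformer_block above.
    if ac_config.mode == "full":
        return "full AC"
    if ac_config.selective_ac_option == "op":
        return "op SAC via the flex-aware wrapper" if use_flex_attn else "op SAC"
    if ac_config.selective_ac_option.isdigit():
        return f"layer SAC every {ac_config.selective_ac_option} blocks"
    raise ValueError(f"Invalid selective AC option: {ac_config.selective_ac_option}")


assert describe_ac_dispatch(_SketchACConfig("full", "op"), False) == "full AC"
assert describe_ac_dispatch(_SketchACConfig("selective", "2"), True).startswith("layer SAC")
```
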
@@ -238,15 +250,16 @@ def apply_ac(
    """Apply activation checkpointing to the model.

    Note that SAC, Flex Attention and model compilation have some conflicts.
-    We explicitly ask the user to pass these configs to warn if there are conflicts.
+    We explicitly ask the user to pass these configs so that we can warn,
+    as the wrapping will be different.

    Args:
        model (nn.Module): The model to apply activation checkpointing to.
        ac_config (ActivationCheckpoint): The activation checkpointing config.
        model_compile_enabled (bool): Whether torch.compile is enabled for the model.
        use_flex_attn (bool): Whether flex attention is enabled for the model.
-        save_list (set[torch._ops.OpOverload]): The list of ops to save when selective
-            activation checkpointing is used.
+        save_list (set[torch._ops.OpOverload]): The list of ops to save instead
+            of recomputing.

    Returns:
        None
    """