|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# This file provides the util functions to apply activation checkpointing to the model. |
| 8 | +# Technically, this is not a part of distributed, but distributed module is the best place to put it. |
| 9 | + |
| 10 | +from collections import defaultdict |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.nn as nn |
| 14 | +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
| 15 | + checkpoint_wrapper as ptd_checkpoint_wrapper, |
| 16 | +) |
| 17 | + |
| 18 | +from torchtitan.config.job_config import ActivationCheckpoint as ACConfig |
| 19 | +from torchtitan.tools.logging import logger, warn_once |
| 20 | + |
| 21 | + |
| 22 | +_layer_sac_count = 0 |
| 23 | + |
| 24 | +# for avoid recomputing SimpleFSDP all_gather in zero2-style FSDP |
| 25 | +# it enforces additional policy to always mark SimpleFSDP all_gather as PREFER_SAVE |
| 26 | +_op_simple_fsdp_save_list = { |
| 27 | + torch.ops._c10d_functional.all_gather_into_tensor.default, |
| 28 | + torch.ops._c10d_functional.wait_tensor.default, |
| 29 | + torch.ops.aten._to_copy.default, |
| 30 | +} |
| 31 | + |
| 32 | +from torch.utils.checkpoint import ( |
| 33 | + CheckpointPolicy, |
| 34 | + create_selective_checkpoint_contexts, |
| 35 | +) |
| 36 | + |
| 37 | + |
| 38 | +def _get_custom_simplefsdp_policy(): |
| 39 | + def _custom_policy(ctx, func, *args, **kwargs): |
| 40 | + # Always save critical communication and copy ops |
| 41 | + if func in _op_simple_fsdp_save_list: |
| 42 | + return CheckpointPolicy.PREFER_SAVE |
| 43 | + return CheckpointPolicy.PREFER_RECOMPUTE |
| 44 | + |
| 45 | + return _custom_policy |
| 46 | + |
| 47 | + |
| 48 | +def simplefsdp_checkpointing_context_fn(): |
| 49 | + return create_selective_checkpoint_contexts(_get_custom_simplefsdp_policy()) |
| 50 | + |
| 51 | + |
| 52 | +def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: |
| 53 | + """Apply layer selective activation checkpointing to the module. |
| 54 | +
|
| 55 | + Args: |
| 56 | + module (nn.Module): The module to apply layer selective activation checkpointing to. |
| 57 | + ac_config (ACConfig): The activation checkpointing config. |
| 58 | +
|
| 59 | + Returns: |
| 60 | + nn.Module: The module with layer selective activation checkpointing applied. |
| 61 | + """ |
| 62 | + global _layer_sac_count |
| 63 | + _layer_sac_count += 1 |
| 64 | + ac_freq = int(ac_config.selective_ac_option) |
| 65 | + if not ac_freq or _layer_sac_count % ac_freq == 0: |
| 66 | + return ptd_checkpoint_wrapper( |
| 67 | + module, |
| 68 | + context_fn=simplefsdp_checkpointing_context_fn, |
| 69 | + preserve_rng_state=ac_config.preserve_rng_state, |
| 70 | + determinism_check=ac_config.determinism_check, |
| 71 | + early_stop=ac_config.early_stop, |
| 72 | + debug=ac_config.debug, |
| 73 | + ) |
| 74 | + else: |
| 75 | + return module |
| 76 | + |
| 77 | + |
| 78 | +def _apply_op_sac( |
| 79 | + module: nn.Module, |
| 80 | + ac_config: ACConfig, |
| 81 | + *, |
| 82 | + base_fqn: str | None = None, |
| 83 | + op_sac_save_list: set[torch._ops.OpOverload], |
| 84 | +) -> nn.Module: |
| 85 | + """Apply selective activation checkpointing to the module. |
| 86 | +
|
| 87 | + Args: |
| 88 | + module (nn.Module): The module to apply selective activation checkpointing to. |
| 89 | + ac_config (ACConfig): The activation checkpointing config. |
| 90 | + base_fqn (str, optional): The base fqn of the module. Defaults to None. |
| 91 | + op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead |
| 92 | + of recomputing. |
| 93 | +
|
| 94 | + Returns: |
| 95 | + nn.Module: The module with selective activation checkpointing applied. |
| 96 | + """ |
| 97 | + from torch.utils.checkpoint import ( |
| 98 | + CheckpointPolicy, |
| 99 | + create_selective_checkpoint_contexts, |
| 100 | + ) |
| 101 | + |
| 102 | + mm_recompute_shapes = set() |
| 103 | + if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: |
| 104 | + for module_fqn, submod in module.named_modules(): |
| 105 | + fqn = module_fqn |
| 106 | + if base_fqn is not None: |
| 107 | + fqn = f"{base_fqn}.{module_fqn}" |
| 108 | + if not any( |
| 109 | + filter_fqn in fqn |
| 110 | + for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns |
| 111 | + ): |
| 112 | + continue |
| 113 | + if not isinstance(submod, nn.Linear): |
| 114 | + raise ValueError( |
| 115 | + "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " |
| 116 | + f"a nn.Linear, but got: {submod}" |
| 117 | + ) |
| 118 | + out_f, in_f = submod.weight.shape |
| 119 | + mm_recompute_shapes.add((in_f, out_f)) |
| 120 | + logger.debug( |
| 121 | + f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" |
| 122 | + ) |
| 123 | + |
| 124 | + def _get_custom_policy(meta): |
| 125 | + def _custom_policy(ctx, func, *args, **kwargs): |
| 126 | + if ( |
| 127 | + func == torch.ops.aten._to_copy.default |
| 128 | + and "cuda" in str(args[0].device) |
| 129 | + and "device" in kwargs |
| 130 | + and str(kwargs["device"]) == "cpu" |
| 131 | + ): |
| 132 | + return CheckpointPolicy.MUST_SAVE |
| 133 | + |
| 134 | + if func in _op_simple_fsdp_save_list: |
| 135 | + return CheckpointPolicy.PREFER_SAVE |
| 136 | + |
| 137 | + mode = "recompute" if ctx.is_recompute else "forward" |
| 138 | + mm_count_key = f"{mode}_mm_count" |
| 139 | + if func == torch.ops.aten.mm.default: |
| 140 | + if args[1].shape in mm_recompute_shapes: |
| 141 | + return CheckpointPolicy.PREFER_RECOMPUTE |
| 142 | + meta[mm_count_key] += 1 |
| 143 | + # Saves output of all compute ops, except every second mm |
| 144 | + to_save = func in op_sac_save_list and not ( |
| 145 | + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 |
| 146 | + ) |
| 147 | + return ( |
| 148 | + CheckpointPolicy.MUST_SAVE |
| 149 | + if to_save |
| 150 | + else CheckpointPolicy.PREFER_RECOMPUTE |
| 151 | + ) |
| 152 | + |
| 153 | + return _custom_policy |
| 154 | + |
| 155 | + def selective_checkpointing_context_fn(): |
| 156 | + meta = defaultdict(int) |
| 157 | + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) |
| 158 | + |
| 159 | + return ptd_checkpoint_wrapper( |
| 160 | + module, |
| 161 | + context_fn=selective_checkpointing_context_fn, |
| 162 | + preserve_rng_state=ac_config.preserve_rng_state, |
| 163 | + determinism_check=ac_config.determinism_check, |
| 164 | + early_stop=ac_config.early_stop, |
| 165 | + debug=ac_config.debug, |
| 166 | + ) |
| 167 | + |
| 168 | + |
| 169 | +def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: |
| 170 | + """Apply full activation checkpointing to the module. |
| 171 | +
|
| 172 | + Args: |
| 173 | + module (nn.Module): The module to apply full activation checkpointing to. |
| 174 | + ac_config (ACConfig): The activation checkpointing config. |
| 175 | +
|
| 176 | + Returns: |
| 177 | + nn.Module: The module with full activation checkpointing applied. |
| 178 | + """ |
| 179 | + return ptd_checkpoint_wrapper( |
| 180 | + module, |
| 181 | + context_fn=simplefsdp_checkpointing_context_fn, |
| 182 | + preserve_rng_state=ac_config.preserve_rng_state, |
| 183 | + determinism_check=ac_config.determinism_check, |
| 184 | + early_stop=ac_config.early_stop, |
| 185 | + debug=ac_config.debug, |
| 186 | + ) |
| 187 | + |
| 188 | + |
| 189 | +def _apply_op_sac_to_transformer_block_with_flex( |
| 190 | + module: nn.Module, |
| 191 | + ac_config: ACConfig, |
| 192 | + *, |
| 193 | + base_fqn: str | None = None, |
| 194 | + model_compile_enabled: bool = False, |
| 195 | + op_sac_save_list: set[torch._ops.OpOverload], |
| 196 | +) -> nn.Module: |
| 197 | + """Apply SAC to the transformer block that uses FlexAttention. |
| 198 | +
|
| 199 | + Args: |
| 200 | + module (nn.Module): The transformer block to apply SAC to. |
| 201 | + ac_config (ACConfig): The Activation Checkpoint config. |
| 202 | + base_fqn (str, optional): The base fqn of the module. Defaults to None. |
| 203 | + model_compile_enabled (bool): Whether model compilation is enabled. |
| 204 | + Defaults to False. |
| 205 | + op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead |
| 206 | + of recomputing. |
| 207 | +
|
| 208 | + Returns: |
| 209 | + nn.Module: The transformer block with SAC applied. |
| 210 | + """ |
| 211 | + |
| 212 | + warn_once( |
| 213 | + logger, |
| 214 | + ( |
| 215 | + "Flex Attention requires compilation for good performance.\n" |
| 216 | + "Thus, torch.compile is always used for Flex Attention, " |
| 217 | + "regardless of the compile.enable flag.\n" |
| 218 | + "However, when selective activation checkpointing (SAC) is enabled, " |
| 219 | + "torch.compile may be invalidated:\n" |
| 220 | + "1. If compile.enable is False, SAC will ignore any torch.compile " |
| 221 | + "inside the SAC region.\n" |
| 222 | + "2. If compile.enable is True but the transformer block contains an MoE module.\n\n" |
| 223 | + "For both cases, we will not wrap the entire TransformerBlock with SAC:\n" |
| 224 | + " - For case 1: SAC will be used for MoE and FeedForward modules, " |
| 225 | + "while full AC will be used for the Attention module.\n" |
| 226 | + " - For case 2: SAC will be applied to MoE and Attention modules if the block " |
| 227 | + "is sparse. But we still apply SAC to an entire dense block.\n" |
| 228 | + ), |
| 229 | + ) |
| 230 | + |
| 231 | + def wrap_submodule(name: str, full_ac: bool = False) -> None: |
| 232 | + submodule = getattr(module, name) |
| 233 | + if full_ac: |
| 234 | + submodule = _apply_full_ac(submodule, ac_config) |
| 235 | + else: |
| 236 | + submodule = _apply_op_sac( |
| 237 | + submodule, |
| 238 | + ac_config, |
| 239 | + base_fqn=f"{base_fqn}.{name}" if base_fqn else name, |
| 240 | + op_sac_save_list=op_sac_save_list, |
| 241 | + ) |
| 242 | + module.register_module(name, submodule) |
| 243 | + |
| 244 | + if hasattr(module, "moe"): |
| 245 | + wrap_submodule("moe", full_ac=False) |
| 246 | + if model_compile_enabled: |
| 247 | + wrap_submodule("attention", full_ac=False) |
| 248 | + else: |
| 249 | + wrap_submodule("attention", full_ac=True) |
| 250 | + else: |
| 251 | + if model_compile_enabled: |
| 252 | + module = _apply_op_sac( |
| 253 | + module, |
| 254 | + ac_config, |
| 255 | + base_fqn=base_fqn, |
| 256 | + op_sac_save_list=op_sac_save_list, |
| 257 | + ) |
| 258 | + else: |
| 259 | + wrap_submodule("feed_forward", full_ac=False) |
| 260 | + wrap_submodule("attention", full_ac=True) |
| 261 | + return module |
| 262 | + |
| 263 | + |
| 264 | +def _apply_ac_to_transformer_block( |
| 265 | + module: nn.Module, |
| 266 | + ac_config: ACConfig, |
| 267 | + *, |
| 268 | + base_fqn: str | None = None, |
| 269 | + model_compile_enabled: bool = False, |
| 270 | + use_flex_attn: bool = False, |
| 271 | + op_sac_save_list: set[torch._ops.OpOverload] | None = None, |
| 272 | +) -> nn.Module: |
| 273 | + valid_ac_modes = ("full", "selective") |
| 274 | + if ac_config.mode not in valid_ac_modes: |
| 275 | + raise ValueError( |
| 276 | + f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" |
| 277 | + ) |
| 278 | + |
| 279 | + if ac_config.mode == "full": |
| 280 | + return _apply_full_ac(module, ac_config) |
| 281 | + |
| 282 | + assert ac_config.mode == "selective", f"{ac_config.mode}" |
| 283 | + use_op_sac = ac_config.selective_ac_option == "op" |
| 284 | + use_layer_sac = ac_config.selective_ac_option.isdigit() |
| 285 | + if not use_op_sac and not use_layer_sac: |
| 286 | + raise ValueError( |
| 287 | + f"Invalid selective AC option: {ac_config.selective_ac_option}. " |
| 288 | + f"Valid options: 'op' or a positive int representing layer frequency" |
| 289 | + ) |
| 290 | + |
| 291 | + if use_op_sac: |
| 292 | + op_sac_save_list = op_sac_save_list or set() |
| 293 | + if use_flex_attn: |
| 294 | + """ |
| 295 | + For Flex Attention, we need to apply SAC carefully to avoid invalidating |
| 296 | + torch.compile. Any torch.compile inside the SAC region will be ignored, |
| 297 | + and any torch.compile outside the SAC region will also be ignored if the |
| 298 | + SAC region contains a graph break (e.g., MoE). |
| 299 | +
|
| 300 | + TODO: remove this once SAC issues are resolved. |
| 301 | + """ |
| 302 | + return _apply_op_sac_to_transformer_block_with_flex( |
| 303 | + module, |
| 304 | + ac_config, |
| 305 | + base_fqn=base_fqn, |
| 306 | + model_compile_enabled=model_compile_enabled, |
| 307 | + op_sac_save_list=op_sac_save_list, |
| 308 | + ) |
| 309 | + else: |
| 310 | + return _apply_op_sac( |
| 311 | + module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list |
| 312 | + ) |
| 313 | + |
| 314 | + return _apply_layer_sac(module, ac_config) |
| 315 | + |
| 316 | + |
| 317 | +def apply_ac( |
| 318 | + model: nn.Module, |
| 319 | + ac_config: ACConfig, |
| 320 | + *, |
| 321 | + model_compile_enabled: bool = False, |
| 322 | + use_flex_attn: bool = False, |
| 323 | + op_sac_save_list: set[torch._ops.OpOverload] | None = None, |
| 324 | + base_folder: str = "", |
| 325 | +) -> None: |
| 326 | + """Apply activation checkpointing to the model. |
| 327 | +
|
| 328 | + Note that SAC, Flex Attention and model compilation have some conflicts. |
| 329 | + We explicitly ask the user to pass these configs to warn as the wrapping |
| 330 | + will be different. |
| 331 | +
|
| 332 | + Args: |
| 333 | + model (nn.Module): The model to apply activation checkpointing to. |
| 334 | + ac_config (ACConfig): The activation checkpointing config. |
| 335 | + model_compile_enabled (bool): Whether torch.compile is enabled for the model. |
| 336 | + use_flex_attn (bool): Whether flex attention is enabled for the model. |
| 337 | + op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead |
| 338 | + of recomputing. |
| 339 | + Returns: |
| 340 | + None |
| 341 | + """ |
| 342 | + for layer_id, transformer_block in model.layers.named_children(): |
| 343 | + transformer_block = _apply_ac_to_transformer_block( |
| 344 | + transformer_block, |
| 345 | + ac_config, |
| 346 | + base_fqn=f"layers.{layer_id}", |
| 347 | + model_compile_enabled=model_compile_enabled, |
| 348 | + use_flex_attn=use_flex_attn, |
| 349 | + op_sac_save_list=op_sac_save_list, |
| 350 | + ) |
| 351 | + model.layers.register_module(layer_id, transformer_block) |
| 352 | + |
| 353 | + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") |
0 commit comments