Conversation

@fegin fegin commented Sep 5, 2025

Flex Attention requires compilation via torch.compile to achieve optimal performance. Therefore, torch.compile is always applied to Flex Attention, regardless of the compile.enable flag. However, when Selective Activation Checkpointing (SAC) is enabled, torch.compile may be bypassed or invalidated under certain conditions:

  1. If compile.enable is set to False, SAC will ignore any torch.compile calls within the SAC region.
  2. If compile.enable is True but the transformer block includes a Mixture of Experts (MoE) module, whose graph break causes the SAC region to fall back to eager mode.

To address this limitation, this PR separates the SAC wrapping of Attention from the MoE and FeedForward modules. This separation ensures that Flex Attention can be compiled successfully even when SAC is enabled. The Attention module is wrapped with full AC if compile.enable is False.
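
A minimal sketch of the resulting wrapping scheme (illustrative only, not the actual torchtitan code: the module names "attention", "moe", "feed_forward", the helper names, and the exact checkpoint_wrapper arguments are assumptions):

```python
from functools import partial

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.utils.checkpoint import create_selective_checkpoint_contexts


def _apply_op_sac(module: nn.Module, save_list) -> nn.Module:
    # Per-op SAC: save only the ops in `save_list`, recompute everything else.
    return ptd_checkpoint_wrapper(
        module,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        context_fn=partial(create_selective_checkpoint_contexts, list(save_list)),
    )


def wrap_block(block: nn.Module, model_compile_enabled: bool, save_list) -> None:
    # MoE / FeedForward always get per-op SAC. Because Attention now lives in
    # its own wrapper, a graph break inside MoE no longer invalidates the
    # compilation of FlexAttention.
    for name in ("moe", "feed_forward"):
        if (m := getattr(block, name, None)) is not None:
            block.register_module(name, _apply_op_sac(m, save_list))

    if (attn := getattr(block, "attention", None)) is not None:
        if model_compile_enabled:
            block.register_module("attention", _apply_op_sac(attn, save_list))
        else:
            # With compile.enable=False, SAC would ignore the torch.compile
            # call around FlexAttention, so fall back to full AC instead.
            block.register_module(
                "attention",
                ptd_checkpoint_wrapper(attn, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
            )
```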

@meta-cla meta-cla bot added the CLA Signed label Sep 5, 2025
@fegin fegin changed the title [Don't review yet] SAC + Flex refactoring Separate SAC Wrapping of MoE and Attention Modules to Enable Flex Attention Compilation Sep 5, 2025
@fegin fegin requested a review from soulitzer September 5, 2025 22:39
"torch.compile may be invalidated:\n"
"1. If compile.enable is False, SAC will ignore any torch.compile "
"inside the SAC region.\n"
"2. If compile.enable is True but the transformer block contains a MoE module.\n\n"

Contributor

oh what's wrong with this?

Also this doesn't sound general -- is it correct that this function will be shared by both dense and sparse models? If so, it could cause a regression for dense models.

@fegin fegin Sep 8, 2025


> oh what's wrong with this?

MoE causes a graph break, which invalidates compilation of the entire AC block. The AC block then runs in eager mode, so FlexAttention is not compiled.

> Also this doesn't sound general -- is it correct that this function will be shared by both dense and sparse models? If so, it could cause a regression for dense models.

That's a good question. I originally kept SAC(TransformerBlock) for dense blocks, but it turns out that the memory usage is no better than SAC(feedforward) + SAC(attention), or even worse. Not sure why. cc @soulitzer

Contributor

If SAC(f(g(x)))'s policy saves the output of g, then SAC(f)(SAC(g)(x)) is probably strictly better than SAC(f(g(x))), since in eager it allows us to clear the rematerialized activations of f before recomputing g.
In this case, the last op of the attention is a matmul, so there's a chance we fall into this case.
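
For concreteness, a toy illustration of the two structures, using plain torch.utils.checkpoint.checkpoint in place of the SAC wrapper (hypothetical f and g; the real behavior also depends on the per-op policy):

```python
import torch
from torch.utils.checkpoint import checkpoint


def g(x):  # inner region
    return x @ x.t()


def f(y):  # outer region
    return torch.relu(y) @ y


x = torch.randn(128, 128, requires_grad=True)

# One region, SAC(f(g(x))): f and g are rematerialized together in backward.
out_joint = checkpoint(lambda t: f(g(t)), x, use_reentrant=False)

# Two regions, SAC(f)(SAC(g)(x)): during backward, f's rematerialized
# activations can be freed before g is recomputed.
out_split = checkpoint(f, checkpoint(g, x, use_reentrant=False), use_reentrant=False)

(out_joint.sum() + out_split.sum()).backward()
```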

Contributor

@soulitzer
I think we should separate forward / backward:

For backward, I roughly get that it may help to "clear the rematerialized activations of f before recomputing g". However, IIRC the memory peak is on forward & loss computation, not on backward, so it may not be "strictly better".

For forward, I understand that it's possible that SAC(f(g(x))) and SAC(f)(SAC(g)(x)) may result in a similar set of activations being saved.

> the last op of the attention is a matmul, so there's a chance we fall into this case.

Plus, I don't think SAC(TransformerBlock) vs. SAC(feedforward) + SAC(attention) is completely analogous to SAC(f(g(x))) vs. SAC(f)(SAC(g)(x)), because in TransformerBlock we also have ffn_norm and attention_norm, whose outputs will be saved under SAC(feedforward) + SAC(attention) but not under SAC(TransformerBlock)?

Maybe ignore what I typed if it looks too messy lol -- what I wanted to convey is that I expect we save more with SAC(feedforward) + SAC(attention).
IIUC the op you mentioned comes from Attention.wo. But our policy is "save every other matmul", so technically we should occasionally be saving more?
Very concretely,

Since DSV3 16B has only 1 MLP layer (https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/deepseek_v3/__init__.py#L87), according to the policy here (https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/activation_checkpoint.py#L94), what happens could be (a rough sketch of that policy follows the two lists below):

SAC(TransformerBlock)

  • the input to every TransformerBlock
  • for the only MLP layer, recompute MLP.w1, save MLP.w2, recompute MLP.w3
  • for every Attention layer, save wq, recompute wkv_a, save wkv_b, recompute wo

SAC(feedforward) + SAC(attention)

  • the input to every attention (results of attention_norm)
  • for the only MLP layer, recompute MLP.w1, save MLP.w2, recompute MLP.w3
  • for every Attention layer, recompute wq, save wkv_a, recompute wkv_b, save wo
  • the input to every feedforward (results of ffn_norm)
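
For reference, a rough sketch of what a "save every other matmul" per-op SAC policy can look like. Only the matmul handling is shown; the real policy in torchtitan/distributed/activation_checkpoint.py also saves other ops from its save list, and the details below (names, counter handling) are assumptions:

```python
import torch
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts


def selective_mm_context_fn():
    # Fresh counters for each checkpointed forward; the recompute pass sees the
    # same policy instance, so its decisions match the forward pass.
    mm_count = {"forward": 0, "recompute": 0}

    def policy_fn(ctx, op, *args, **kwargs):
        if op != torch.ops.aten.mm.default:
            return CheckpointPolicy.PREFER_RECOMPUTE
        key = "recompute" if ctx.is_recompute else "forward"
        mm_count[key] += 1
        # Save every other matmul (the 1st, 3rd, ... mm in the region),
        # recompute the rest. Which projections (wq / wkv_a / wkv_b / wo,
        # MLP.w1 / w2 / w3) land on the "save" side therefore depends on where
        # the SAC region starts, which is what the two lists above enumerate.
        if mm_count[key] % 2 == 1:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return create_selective_checkpoint_contexts(policy_fn)


# Example use with a plain checkpoint call:
# out = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False,
#                                         context_fn=selective_mm_context_fn)
```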

Contributor Author

@tianyu-l I'm confused

> Maybe ignore what I typed if it looks too messy lol -- what I wanted to convey is that I expect we save more with SAC(feedforward) + SAC(attention).

Do you expect SAC(feedforward) + SAC(attention) to save more memory, or SAC(TransformerBlock)?

From the experiment, it is SAC(feedforward) + SAC(attention) that uses less memory. But in your original comment you mentioned you expected a regression if I do SAC(feedforward) + SAC(attention).

@fegin fegin Sep 9, 2025

I changed the local batch size to 4 to avoid OOM.

Compiler Disabled (Feedforward is wrapped with SAC and Attention is wrapped with full AC)

step: 10  loss:  8.6712  grad_norm: 13.4708  memory: 42.29GiB(44.51%)  tps: 5,636  tflops: 86.73  mfu: 8.77%

Compiler Enabled (Feedforward is wrapped with SAC and Attention is wrapped with full AC)

step: 10  loss:  8.6271  grad_norm:  4.1878  memory: 42.16GiB(44.38%)  tps: 7,106  tflops: 109.35  mfu: 11.06%

Compiler Enabled (Attention and FeedForward are wrapped separately with SAC)

step: 10  loss:  8.4554  grad_norm:  7.7458  memory: 44.22GiB(46.55%)  tps: 7,627  tflops: 117.37  mfu: 11.87%

Compiler Enabled (Attention and FeedForward are wrapped together with full AC)

step: 10  loss:  8.6835  grad_norm:  5.3193  memory: 51.41GiB(54.11%)  tps: 8,808  tflops: 135.54  mfu: 13.70%

Compiler Enabled (Attention and FeedForward are wrapped together with SAC)

step: 10  loss:  8.8743  grad_norm:  8.6252  memory: 52.98GiB(55.77%)  tps: 8,758  tflops: 134.76  mfu: 13.63%

Compiler Enabled (No AC applied to the TransformerBlock with FeedForward)

step: 10  loss:  8.8352  grad_norm:  3.7476  memory: 52.56GiB(55.32%)  tps: 9,802  tflops: 150.83  mfu: 15.25%

All of these settings use the same wrapping for TransformerBlocks with MoE. Could the mystery come from the compiler + AC/SAC interaction? @soulitzer

@soulitzer soulitzer Sep 9, 2025

Thanks for the data. I do notice that the "wrapped together" cases use significantly more memory.

> FlexAttention is compiled because the outer SAC is compiled.

So this is without MoEs, i.e., you set the number of dense layers to the total number of layers, and you are able to compile with fullgraph=True?

If there are mostly non-dense layers, then I'd still imagine that the graph break in the MoEs would prevent Flex from being compiled.

Contributor Author

I may have miscommunicated. Let me summarize how these experiments were done.

  1. This is a 16B model, with one dense layer. All the configuration changes are applied to that dense block only.

  2. If a TransformerBlock has MoE, aka a sparse block, its Attention module is always wrapped separately. So the graph break from MoE doesn't prevent Flex from being compiled.

  3. The different configurations are applied to the only dense block, the TransformerBlock that has FeedForward but not MoE. There should be no graph breaks in this dense block, so even wrapping the FeedForward with Attention together with SAC should compile Flex correctly.

  4. If you check the last configuration in the experiment, I didn't apply SAC or full AC to the dense block. Its memory usage is very similar to the other cases where FeedForward and Attention are wrapped together.

Contributor

I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock), because technically the former should save more (namely the norm results). If so, this would be a regression (in terms of memory) for dense models under Flex + compile.

But that doesn't seem to be the case?

Contributor

Ah I see, thanks for explaining. I guess I'm pretty confused by these results. Going to think about it more tomorrow.

Some thoughts so far:

  • in compile, unless the AC regions are adjacent, the inputs to AC aren't force-saved (not sure how I feel about that tbh).
  • Attention (4 mms) happens before MLP (3 mms), so it would be (recompute, save, recompute, save) for the mms in the first SAC region on the Attention and (recompute, save, recompute) for the mms in the second SAC region on the MLP, whether there is one SAC region or two! (As a test we can change the SAC policy to save all matmuls, for example; see the snippet after this list.) I was thinking about
  • Partitioner itself will do some recompute, so maybe it's not too surprising that SAC results can be the same as no AC at all, e.g. perhaps RMSNorm is fusible and thus gets recomputed by default, and it also decides to save mms since they are compute intensive.
  • Quite confusingly, wrapping the TransformerBlock with full AC isn't using the least amount of memory among all these options.
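
The "save all matmuls" test mentioned above could be as simple as passing a plain op list instead of a parity-based policy (a hedged sketch; the variable name is made up):

```python
from functools import partial

import torch
from torch.utils.checkpoint import create_selective_checkpoint_contexts

# Every aten.mm output is saved; all other ops are recomputed.
save_all_mms_context_fn = partial(
    create_selective_checkpoint_contexts,
    [torch.ops.aten.mm.default],
)
```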


Args:
    model (nn.Module): The model to apply activation checkpointing to.
    ac_config (ActivationCheckpoint): The activation checkpointing config.
    model_compile_enabled (bool): Whether torch.compile is enabled for the model.
    use_flex_attn (bool): Whether flex attention is enabled for the model.
    save_list (set[torch._ops.OpOverload]): The list of ops to save when selective
        activation checkpointing is used.
    save_list (set[torch._ops.OpOverload]): The list of ops to save instead

Contributor

maybe we should name it explicitly to denote that it's only used by per-op SAC

if (m := getattr(module, name, None)) is not None:
    module.register_module(
        name,
        _apply_op_sac(

Contributor

For MoE + compile enabled, IIUC the graph break will still cause SAC to run in eager. It's just that there's no FlexAttention, so nothing bad happens. Is this correct?

Contributor Author

Yes

"torch.compile may be invalidated:\n"
"1. If compile.enable is False, SAC will ignore any torch.compile "
"inside the SAC region.\n"
"2. If compile.enable is True but the transformer block contains a MoE module.\n\n"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a regression (in terms of memory) to dense model under Flex + compile.

But it seems not the case?

            save_list=save_list,
        ),
    )
    if model_compile_enabled:

Contributor

From a user perspective, I'd hope compile on/off doesn't simultaneously change other settings like this.
Do you think we can always do full AC on attention when Flex + SAC is used, eager or compile?
Otherwise, in extreme cases it's possible that with compile we'd be seeing slower throughput (due to more recomputation), which would be very unintuitive.
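
A sketch of that simpler rule, reusing the hypothetical _apply_op_sac / ptd_checkpoint_wrapper names from the sketch at the top of the thread (the ac_mode and use_flex_attn flags, and the attribute name "attention", are assumptions):

```python
def wrap_attention(block, *, use_flex_attn: bool, ac_mode: str, save_list) -> None:
    # Whenever FlexAttention is used together with SAC, always apply full AC to
    # the Attention module, independent of whether torch.compile is enabled.
    attn = block.attention
    if use_flex_attn and ac_mode == "selective":
        block.register_module(
            "attention",
            ptd_checkpoint_wrapper(attn, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
        )
    else:
        block.register_module("attention", _apply_op_sac(attn, save_list))
```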

Contributor Author

> I think my confusion was:
> under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock), because technically the former should save more (namely the norm results). If so, this would be a regression (in terms of memory) for dense models under Flex + compile.
> But that doesn't seem to be the case?

Yes, this needs more investigation. It's also not clear to me why the result is counterintuitive.

> Do you think we can always do full AC on attention when Flex + SAC is used, eager or compile?

Yes, this is a valid option and I have tested it. The result doesn't show any downsides, the logic is simpler, and the UX is much clearer. But I have to re-verify it, since after #1672 the memory usage is different.

@fegin fegin force-pushed the chienchin/flex_sac_hack2 branch from 8fed32a to a2570f2 on September 10, 2025 05:10