Separate SAC Wrapping of MoE and Attention Modules to Enable Flex Attention Compilation #1683
Conversation
"torch.compile may be invalidated:\n" | ||
"1. If compile.enable is False, SAC will ignore any torch.compile " | ||
"inside the SAC region.\n" | ||
"2. If compile.enable is True but the transformer block contains a MoE module.\n\n" |
oh what's wrong with this?
Also this doesn't sound general -- is it correct that this function will be shared by both dense and sparse models? If so, for dense models it could cause regression.
oh what's wrong with this?
MoE is causing a graph break, which invalidates the entire AC block compilation. The AC block will be run under eager. FlexAttention will not be compiled.
Also this doesn't sound general -- is it correct that this function will be shared by both dense and sparse models? If so, for dense models it could cause regression.
That's a good question. I originally kept SAC(TransformerBlock) for dense blocks. But it turns out that the memory usage is no better than just SAC(feedforward) + SAC(attention), or is even worse. Not sure why. cc @soulitzer
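For concreteness, a minimal sketch of the two wrapping schemes being compared (plain AC via PyTorch's checkpoint_wrapper is used here for brevity instead of per-op SAC; the attribute names block.attention / block.moe are illustrative, not the actual torchtitan code):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def wrap_whole_block(block: nn.Module) -> nn.Module:
    # One AC region around the entire TransformerBlock: a graph break inside the
    # region (e.g. from MoE token routing) makes torch.compile fall back to eager
    # for the whole region, so the FlexAttention call inside never gets compiled.
    return checkpoint_wrapper(block, preserve_rng_state=False)


def wrap_submodules_separately(block: nn.Module) -> nn.Module:
    # One AC region per submodule: attention can be compiled on its own, and the
    # MoE graph break only affects the region that actually contains the MoE.
    block.attention = checkpoint_wrapper(block.attention, preserve_rng_state=False)
    block.moe = checkpoint_wrapper(block.moe, preserve_rng_state=False)
    return block
```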
If SAC(f(g(x)))'s policy saves the output of g, then SAC(f)(SAC(g(x))) is probably strictly better than SAC(f(g(x))), since in eager it allows us to clear the rematerialized activations of f before recomputing g.
In this case, the last op of the attention is a matmul, so there's a chance we fall into this case.
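A toy eager-mode illustration of that regioning difference, using plain AC instead of per-op SAC (f, g, and the shapes are arbitrary; the point is only when rematerialized activations can be freed during backward):

```python
import torch
from torch.utils.checkpoint import checkpoint


def g(x):
    return torch.relu(x @ x.T)


def f(y):
    return (y @ y.T).sum()


x = torch.randn(128, 128, requires_grad=True)

# SAC(f(g(x))): one region. During backward, both g and f are rematerialized
# before any of their intermediate activations can be freed.
loss_fused = checkpoint(lambda t: f(g(t)), x, use_reentrant=False)
loss_fused.backward()

# SAC(f)(SAC(g(x))): two regions. f's region is recomputed and backpropped first,
# and its rematerialized activations can be freed before g's region is recomputed.
x.grad = None
loss_split = checkpoint(f, checkpoint(g, x, use_reentrant=False), use_reentrant=False)
loss_split.backward()
```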
@soulitzer
I think we should separate forward / backward:
For backward, I roughly get that it may help if we "clear the rematerialized activations of f before recomputing g". However, IIRC the memory peak is on forward & loss computation, not on backward, so it may not be "strictly better".
For forward, I understand that it's possible that SAC(f(g(x))) and SAC(f)(SAC(g(x))) may result in a similar set of activations being saved.
the last op of the attention is matmul, so there's a chance we fall into this case.
Plus I don't think SAC(TransformerBlock) vs. SAC(feedforward) + SAC(attention) is completely analogous to SAC(f(g(x))) vs. SAC(f)(SAC(g(x))), because in TransformerBlock we also have ffn_norm and attention_norm, whose outputs will be saved in SAC(feedforward) + SAC(attention) but not in SAC(TransformerBlock)?
Maybe ignore what I typed if it looks too messy lol -- what I wanted to convey is I expect we save more with SAC(feedforward) + SAC(attention).
IIUC the op you mentioned comes from Attention.wo. But our policy is "save every other matmul", so technically we should occasionally be saving more?
Very concretely,
- MLP has 3 matmuls: w1, w2, w3
- Attention (DSV3 16B) has 4 matmuls: wq, wkv_a, wkv_b, wo
- MoE has 3 grouped_mm, and 1 matmul from router.gate, but that one is not in the "save every other matmul" regime due to https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/activation_checkpoint.py#L90

Since DSV3 16B has only 1 MLP layer (https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/deepseek_v3/__init__.py#L87), according to the policy here https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/activation_checkpoint.py#L94, what happens could be:

SAC(TransformerBlock)
- the input to every TransformerBlock
- for the only MLP layer: recompute MLP.w1, save MLP.w2, recompute MLP.w3
- for every Attention layer: save wq, recompute wkv_a, save wkv_b, recompute wo

SAC(feedforward) + SAC(attention)
- the input to every attention (results of attention_norm)
- for the only MLP layer: recompute MLP.w1, save MLP.w2, recompute MLP.w3
- for every Attention layer: recompute wq, save wkv_a, recompute wkv_b, save wo
- the input to every feedforward (results of ffn_norm)
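As a reading aid for the walkthrough above, here is a hedged sketch of a "save every other matmul" per-op SAC policy in the spirit of the linked activation_checkpoint.py (counter keys and helper names are illustrative, not the actual torchtitan code):

```python
from collections import defaultdict
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def make_save_every_other_mm_policy(counters):
    def policy(ctx, func, *args, **kwargs):
        if func == torch.ops.aten.mm.default:
            # Count matmuls separately for the original forward and the recompute
            # pass so both passes make the same save/recompute decisions.
            key = "recompute_mm" if ctx.is_recompute else "forward_mm"
            counters[key] += 1
            # Save the 1st, 3rd, 5th, ... matmul output; recompute the rest.
            if counters[key] % 2 == 1:
                return CheckpointPolicy.MUST_SAVE
        # Everything else is preferred for recompute.
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy


def sac(fn, *args):
    # One fresh counter dict per SAC region, mirroring the per-region counting
    # discussed above.
    counters = defaultdict(int)
    context_fn = partial(
        create_selective_checkpoint_contexts, make_save_every_other_mm_policy(counters)
    )
    return checkpoint(fn, *args, use_reentrant=False, context_fn=context_fn)
```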
@tianyu-l I'm confused
Maybe ignore what I typed if it looks too messy lol -- what I wanted to convey is I expect we save more with SAC(feedforward) + SAC(attention)
Do you expect that SAC(feedforward) + SAC(attention) saves more memory, or SAC(TransformerBlock) saves more memory?
From the experiment, it is SAC(feedforward) + SAC(attention). But you mentioned in the original comment that you expected a regression if I do SAC(feedforward) + SAC(attention).
Changed the local batch size to 4 to avoid OOM.
Compiler Disabled (Feedforward is wrapped with SAC and Attention is wrapped with full AC)
step: 10 loss: 8.6712 grad_norm: 13.4708 memory: 42.29GiB(44.51%) tps: 5,636 tflops: 86.73 mfu: 8.77%
Compiler Enabled (Feedforward is wrapped with SAC and Attention is wrapped with full AC)
step: 10 loss: 8.6271 grad_norm: 4.1878 memory: 42.16GiB(44.38%) tps: 7,106 tflops: 109.35 mfu: 11.06%
Compiler Enabled (Attention and FeedForward are wrapped separately with SAC)
step: 10 loss: 8.4554 grad_norm: 7.7458 memory: 44.22GiB(46.55%) tps: 7,627 tflops: 117.37 mfu: 11.87%
Compiler Enabled (Attention and FeedForward are wrapped together with full AC)
step: 10 loss: 8.6835 grad_norm: 5.3193 memory: 51.41GiB(54.11%) tps: 8,808 tflops: 135.54 mfu: 13.70%
Compiler Enabled (Attention and FeedForward are wrapped together with SAC)
step: 10 loss: 8.8743 grad_norm: 8.6252 memory: 52.98GiB(55.77%) tps: 8,758 tflops: 134.76 mfu: 13.63%
Compiler Enabled (No AC applied to the TransformerBlock with FeedForward)
step: 10 loss: 8.8352 grad_norm: 3.7476 memory: 52.56GiB(55.32%) tps: 9,802 tflops: 150.83 mfu: 15.25%
All of these settings have the same wrapping around the TransformerBlock with MoE. Can the mystery come from compiler + AC/SAC interaction? @soulitzer
Thanks for the data. I do notice that the "wrapped together" cases use significantly more memory.
FlexAttention is compiled because the outer SAC is compiled.
So this is without MoEs, i.e., you set the number of dense layers to the total number of layers, and you are able to compile with fullgraph=True?
If the layers are mostly non-dense, then I'd still imagine that the graph break in the MoEs would prevent Flex from being compiled.
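For reference on the compile question, FlexAttention is usually compiled standalone, along the lines of this minimal sketch (arbitrary shapes/dtype; assumes a CUDA device and a PyTorch build with flex attention):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compiling the flex_attention entry point is what enables the fused kernels;
# in eager it falls back to a slower reference implementation.
compiled_flex_attention = torch.compile(flex_attention)

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
out = compiled_flex_attention(q, k, v)
```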
I may have miscommunicated. Let me put a summary of how these experiments were done:
- This is a 16B model with one dense layer. All the configuration changes are applied to that dense block only.
- If a TransformerBlock has MoE, aka a sparse block, its Attention module is always wrapped separately. So the graph break from MoE doesn't prevent Flex from being compiled.
- The different configurations are applied to the only dense block, the TransformerBlock that has FeedForward but not MoE. There should be no graph breaks in this dense block, so even wrapping the FeedForward and Attention together with SAC should compile Flex correctly.
- If you check the last configuration in the experiment, I didn't apply SAC nor AC to the dense block. Its memory usage is very similar to the other cases where FeedForward and Attention are wrapped together.
I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a regression (in terms of memory) to dense model under Flex + compile.
But it seems not the case?
Ah I see, thanks for explaining. I guess I'm pretty confused by these results. Going to think about it more tomorrow.
Some thoughts so far:
- in compile, unless the AC regions are adjacent, the inputs to AC aren't force-saved (not sure how I feel about that tbh).
- Attention (4 mms) happens before MLP (3 mms), so it would be (recompute, save, recompute, save) for the mms in the first SAC region on the Attention and (recompute, save, recompute) in the second SAC region on the MLP, whether there is one SAC region or two! (As a test we can change the SAC policy to save all matmuls, for example; see the sketch after this list.)
- Partitioner itself will do some recompute, so maybe it's not too surprising that SAC results can be the same as no AC at all, e.g. perhaps RMSNorm is fusible and thus gets to be recomputed by default, and it also decides to save mms since they are compute intensive.
- Quite confusingly, wrapping TransformerBlock with full AC isn't using the least amount of memory among all these options.
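For the test suggested in the second bullet, a minimal policy that saves every matmul output, which could be dropped into the same create_selective_checkpoint_contexts machinery as the earlier sketch (illustrative, not torchtitan code):

```python
import torch
from torch.utils.checkpoint import CheckpointPolicy


def save_all_mm_policy(ctx, func, *args, **kwargs):
    # Save every matmul output instead of every other one; prefer recomputing
    # all other ops.
    if func == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE
```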
Args:
    model (nn.Module): The model to apply activation checkpointing to.
    ac_config (ActivationCheckpoint): The activation checkpointing config.
    model_compile_enabled (bool): Whether torch.compile is enabled for the model.
    use_flex_attn (bool): Whether flex attention is enabled for the model.
    save_list (set[torch._ops.OpOverload]): The list of ops to save when selective
        activation checkpointing is used.
    save_list (set[torch._ops.OpOverload]): The list of ops to save instead
maybe should name it explicitly to denote it's only used by per-op SAC
if (m := getattr(module, name, None)) is not None:
    module.register_module(
        name,
        _apply_op_sac(
For MoE + compile enabled, IIUC the graph break will still cause SAC to run in eager. It's just that there's no FlexAttention, so nothing bad happens. Is this correct?
Yes
"torch.compile may be invalidated:\n" | ||
"1. If compile.enable is False, SAC will ignore any torch.compile " | ||
"inside the SAC region.\n" | ||
"2. If compile.enable is True but the transformer block contains a MoE module.\n\n" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a regression (in terms of memory) to dense model under Flex + compile.
But it seems not the case?
save_list=save_list,
),
)
if model_compile_enabled:
From a user perspective, I'd hope compile on/off doesn't simultaneously change other settings like this.
Do you think we can always do full AC on attention when Flex + SAC is used, eager or compile?
O/w in extreme cases it's possible that with compile, we are seeing slower throughput (due to more recomputation), but it'll be very unintuitive.
I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a regression (in terms of memory) to dense model under Flex + compile.
But it seems not the case?
Yes, this needs more investigation. It's also not clear to me why the result is counterintuitive.
Do you think we can always do full AC on attention when Flex + SAC is used, eager or compile?
Yes, this is a valid option and I have tested it. The result doesn't show any downsides, the logic is simpler, and the UX is much clearer. But I have to re-verify it, as the memory usage is different after #1672.
Flex Attention requires compilation via torch.compile to achieve optimal performance. Therefore, torch.compile is always applied to Flex Attention, regardless of the compile.enable flag. However, when Selective Activation Checkpointing (SAC) is enabled, torch.compile may be bypassed or invalidated under certain conditions:
1. If compile.enable is False, SAC will ignore any torch.compile inside the SAC region.
2. If compile.enable is True but the transformer block contains a MoE module, the resulting graph break invalidates compilation of the whole SAC region.
To address this limitation, this PR separates the SAC wrapping of Attention from the MoE and FeedForward modules. This separation ensures that Flex Attention can be compiled successfully even when SAC is enabled. The Attention module is wrapped with full AC if compile.enable is False.
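A hedged sketch of the resulting wrapping logic (attribute names like attention / moe / feed_forward and the helper apply_op_sac are illustrative placeholders, not the exact functions in this PR):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def wrap_transformer_block(
    block: nn.Module,
    model_compile_enabled: bool,
    apply_op_sac,  # callable that wraps a module with per-op SAC
) -> nn.Module:
    # MoE / FeedForward always get their own per-op SAC region, so a graph break
    # inside MoE never sits in the same region as the attention module.
    for name in ("moe", "feed_forward"):
        if (m := getattr(block, name, None)) is not None:
            block.register_module(name, apply_op_sac(m))

    if model_compile_enabled:
        # Attention gets its own per-op SAC region, separate from any MoE graph
        # break, so the compiled FlexAttention inside it stays compiled.
        block.register_module("attention", apply_op_sac(block.attention))
    else:
        # With compile.enable=False, SAC would ignore the torch.compile applied
        # to FlexAttention, so fall back to full AC for the attention module.
        block.register_module(
            "attention",
            checkpoint_wrapper(block.attention, preserve_rng_state=False),
        )
    return block
```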