Save _to_copy and a2a in selective AC policy #1672
Conversation
  num_tokens_per_expert.view(ep_size, -1)
  .sum(dim=1)
- .to(torch.device("cpu"), non_blocking=True)
+ .to(torch.device("cpu"), non_blocking=False)
Oh, could you remind me of the reason we have to use non_blocking=False? I think it may not matter too much, as these two d2h syncs are adjacent to each other.
If we have to do this, we can remove the non_blocking arg, as False is the default.
That is a good point, we can avoid blocking for the first .to(), although I don't think it changed tps very much.
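
For context, a minimal runnable sketch of the pattern under discussion (the tensors are stand-ins for the real routing counts, and a CUDA device is assumed): the first copy can be issued with non_blocking=True so it is merely queued, while the second, blocking copy provides the single host sync that both adjacent d2h transfers share.

import torch

# Sketch only: stand-ins for the routing counts from the surrounding
# expert-parallel code; assumes a CUDA device is available.
ep_size = 4
num_tokens_per_expert = torch.randint(0, 8, (ep_size * 2,), device="cuda")
num_tokens_per_expert_group = torch.randint(0, 8, (ep_size * 2,), device="cuda")

input_splits = (
    num_tokens_per_expert.view(ep_size, -1)
    .sum(dim=1)
    .to(torch.device("cpu"), non_blocking=True)  # queued; no host sync yet
)
output_splits = (
    num_tokens_per_expert_group.view(ep_size, -1)
    .sum(dim=1)
    # This blocking copy is the actual device-to-host sync; because both copies
    # run on the same stream, input_splits is also guaranteed to be ready here.
    .to(torch.device("cpu"), non_blocking=False)
)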
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
torch.ops._c10d_functional.all_to_all_single.default,
Could you confirm that this saves both:
- the dist.all_to_all_single to obtain routing info
- the actual all_to_all_single_autograd to route tokens

I think ideally we'd like both to be saved.
Good catch, dist.all_to_all_single is actually a different op.
Do you know why we use two different all-to-alls here? I don't think dist.all_to_all_single works with SAC: it mutates an "output" tensor that the user provides and returns a work object.
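
For reference, a sketch of why the in-place collective is hard to save (process-group setup is assumed, e.g. dist.init_process_group under torchrun): the result lands in a caller-provided tensor and the return value is only a work handle, so there is no functional op output for an SAC policy to hold onto.

import torch
import torch.distributed as dist

def in_place_a2a(inp: torch.Tensor) -> torch.Tensor:
    # Sketch only: assumes dist.init_process_group(...) has already run.
    out = torch.empty_like(inp)
    # Mutates `out` in place; the return value is a Work handle (with
    # async_op=True) or None, not a tensor an SAC policy could save.
    work = dist.all_to_all_single(out, inp, async_op=True)
    work.wait()
    return out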
Oh, there may not be a strong reason to.
Could you try the fun col version? https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L445
If it works we can switch to that one, and hopefully the AC policy would capture both, because underneath the same torch.ops._c10d_functional.all_to_all_single.default gets called.
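
A sketch of what switching to the functional collective could look like (the even split sizes and the group argument are illustrative assumptions):

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single

def functional_a2a(inp: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Assumes inp.shape[0] is divisible by the group size (even split).
    world = group.size()
    splits = [inp.shape[0] // world] * world
    # This dispatches to torch.ops._c10d_functional.all_to_all_single.default,
    # which is the op the SAC save list matches on.
    out = all_to_all_single(inp, splits, splits, group)
    # out is an AsyncCollectiveTensor; wait_tensor resolves it to a plain tensor.
    return torch.ops._c10d_functional.wait_tensor(out)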
Updated to use the fun col version.
Just an FYI, @soulitzer, #1675 conflicts with this PR.
)
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
    num_tokens_per_expert_group
)
We need an explicit wait because num_tokens_per_expert_group gets used by a triton kernel, which doesn't realize that AsyncCollectiveTensor needs to be unwrapped.
could you make this a comment in the code? I think it's very helpful.
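
For example, the explanation could be folded into the code roughly like this (a sketch of the snippet above with the requested comment added):

# num_tokens_per_expert_group is an AsyncCollectiveTensor from the functional
# all_to_all_single. It gets used by a triton kernel, which doesn't realize
# that the AsyncCollectiveTensor needs to be unwrapped, so wait explicitly here.
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
    num_tokens_per_expert_group
)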
Had one more question regarding saving the results of wait_tensor.
Also, it would be great if you could share some profiler traces in the PR summary.
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
torch.ops._c10d_functional.all_to_all_single.default,
torch.ops._c10d_functional.wait_tensor.default,
I wonder if this has any side effects, since the mapping from collectives to wait is many-to-one. In particular:
- Would this line save all the communication results, not only from a2a but also e.g. the TP all-gather?
- Would not having this line save none of the communication results? I.e., did the torch.ops._c10d_functional.reduce_scatter_tensor.default line take effect?
I don't think there should be a side effect, unless there are other wait_tensors being explicitly called.
- Ordinarily, the AsyncCollectiveTensor triggers the wait before executing the op in its torch dispatch, so it would actually be hidden from SAC (user modes execute before user subclasses unwrap). SAC should only be able to see / save the wait if we're calling it explicitly here.
- The lines for reduce scatter, etc. will save an AsyncCollectiveTensor, and in the original forward, when the wait happens via the subclass's torch dispatch, the wait result should be cached onto the AsyncCollectiveTensor, so a second wait should not be triggered during recompute.

That being said, I'm not entirely sure what happens when you execute wait explicitly on an AsyncCollectiveTensor whose collective has already been waited on. Checking again, removing it doesn't seem to affect tps, so I think I will remove it.
Added links to some profiler traces in the summary. From staring at the traces, saving the wait_tensor reduces the CPU overhead from 200us to 40us but doesn't really seem to affect tps, so I'm removing it to minimize the risk of side effects.
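
For readers following the thread, a hedged sketch of how an op save list like this typically plugs into selective AC via torch.utils.checkpoint (the actual torchtitan policy may carry additional logic; this only illustrates the mechanism, including the _to_copy-to-cpu branch discussed below):

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Ops whose forward results are saved rather than recomputed in backward.
_save_list = {
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    torch.ops._c10d_functional.all_to_all_single.default,
}

def _policy_fn(ctx, op, *args, **kwargs):
    # Save d2h copies (aten._to_copy targeting cpu) so the blocking sync is
    # not repeated during recompute; save listed ops; recompute everything else.
    if (
        op == torch.ops.aten._to_copy.default
        and "device" in kwargs
        and str(kwargs["device"]) == "cpu"
    ):
        return CheckpointPolicy.MUST_SAVE
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

# Usage sketch: wrap a transformer block's forward with SAC using this policy.
# out = checkpoint(
#     block, x, use_reentrant=False,
#     context_fn=partial(create_selective_checkpoint_contexts, _policy_fn),
# )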
- .to(torch.device("cpu"), non_blocking=True)
+ .to(torch.device("cpu"), non_blocking=False)
  )
  # NOTE: this would incur a device-to-host sync
please move this note to the actual blocking call above
and "device" in kwargs | ||
and str(kwargs["device"]) == "cpu" | ||
): | ||
return CheckpointPolicy.MUST_SAVE |
hmm did this take effect? I would guess we don't need to do any d2h sync in backward anymore, but in the traces I'm still seeing them in backward.
(save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch
(don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch
Hmm, there is still a cudaStreamSync in FlexAttentionBackward, but that is expected since SAC only takes effect for the replay of the forward. Is there another place where you see it?
looks great!
Before (not saving a2a and to_copy)
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_recompute.json.gz&bucket=pytorch
After (saving a2a and to_copy)
(save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch
(don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch