Save _to_copy and a2a in selective AC policy #1672
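The PR description was not captured on this page, so for context on the title: "saving" an op in a selective activation checkpointing (SAC) policy means storing its output during the forward pass rather than recomputing it when the checkpointed region is replayed for backward. Saving the _to_copy and all-to-all (a2a) ops keeps the replay from re-issuing communication and the device-to-host sync that appears in the hunks below. The following is only a minimal sketch of such a policy built on torch.utils.checkpoint's public SAC API; the op list and wiring are assumptions, and the PR's actual policy change is not part of the hunks captured here.

import functools

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

# Assumed save list: outputs of these ops are stored during the forward pass
# instead of being recomputed when the checkpointed region is replayed.
_save_list = {
    torch.ops.aten._to_copy.default,
    torch.ops._c10d_functional.all_to_all_single.default,
}


def _policy_fn(ctx, op, *args, **kwargs):
    # Save the listed ops; let everything else follow the default
    # recompute-on-backward behavior.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


# Hypothetical usage: pass as context_fn to torch.utils.checkpoint.checkpoint
# (with use_reentrant=False) around the transformer block.
context_fn = functools.partial(create_selective_checkpoint_contexts, _policy_fn)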
@@ -8,9 +8,11 @@
 from typing import Callable, Literal

 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
+from torch.distributed._functional_collectives import (
+    all_to_all_single,
+    all_to_all_single_autograd,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -90,26 +92,28 @@ def _token_dispatch(self, mod, inputs, device_mesh):

         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
-            num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
-                num_tokens_per_expert.shape[0]
-            )
-            dist.all_to_all_single(
-                num_tokens_per_expert_group,
+            num_tokens_per_expert_group = all_to_all_single(
                 num_tokens_per_expert,
+                None,
+                None,
                 group=device_mesh.get_group(),
             )
+            # Need to wait explicitly because it is used by a triton kernel later
+            # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )

Inline review comment: We need an explicit wait because num_tokens_per_expert_group gets used by a Triton kernel, which doesn't realize that AsyncCollectiveTensor needs to be unwrapped. (A standalone sketch of this pattern follows after the diff.)

Reply: Could you make this a comment in the code? I think it's very helpful.

             input_splits = (
                 num_tokens_per_expert.view(ep_size, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=True)
             )
+            # NOTE: this would incur a device-to-host sync
             output_splits = (
                 num_tokens_per_expert_group.view(ep_size, -1)
                 .sum(dim=1)
-                .to(torch.device("cpu"), non_blocking=True)
+                .to(torch.device("cpu"), non_blocking=False)
             )
-            # NOTE: this would incur a device-to-host sync
-            torch.cuda.current_stream().synchronize()
             self.input_splits = input_splits.tolist()
             self.output_splits = output_splits.tolist()
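
To make the inline comment above concrete: functional collectives such as all_to_all_single return right away, and the result can be an AsyncCollectiveTensor that completes lazily, so code that bypasses normal tensor-subclass dispatch (for example a Triton kernel) needs the communication forced to completion and the result unwrapped first. Below is a minimal standalone sketch, assuming an initialized default process group and one count per destination rank; the names and shapes are illustrative and not taken from the PR.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single

# Assumes torch.distributed is already initialized (e.g. via torchrun) and a
# CUDA device is bound to this rank.
world_size = dist.get_world_size()
counts = torch.arange(world_size, device="cuda", dtype=torch.int64)

# The functional collective is issued asynchronously; the returned value may be
# an AsyncCollectiveTensor wrapping a communication that has not finished yet.
gathered = all_to_all_single(counts, None, None, group=dist.group.WORLD)

# Force completion and unwrap to a plain tensor before handing the buffer to
# code that does not go through tensor-subclass dispatch (e.g. a Triton kernel).
gathered = torch.ops._c10d_functional.wait_tensor(gathered)
splits = gathered.tolist()  # completed, plain tensor; fine to consume anywhere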

Review comment: Hmm, did this take effect? I would guess we don't need to do any D2H sync in backward anymore, but in the traces I'm still seeing them in backward.

Reply: Hmm, there is still a cudaStreamSync in FlexAttentionBackward, but that is expected, since SAC only takes effect for the replay of the forward. Is there another place where you see it?

Reply: According to @drisspg, this H2D of FlexAttentionBackward is from the eager implementation; #1683 will fix the FlexAttention compilation issue.
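
On the device-to-host syncs discussed in this thread: the hunk above drops the explicit torch.cuda.current_stream().synchronize() and instead makes the output_splits copy blocking, so the host stalls only where the split values are actually needed, and with the all-to-all output saved by the SAC policy this code is not replayed during backward. Below is a small standalone illustration of the two host-sync patterns; the tensor and the reduction are made up and are not the PR's code.

import torch

x = torch.randn(1024, device="cuda")
counts = (x > 0).sum()  # a device-side value the host needs to read

# Stream-wide sync pattern (what the old code did): queue a non-blocking copy,
# then block the host on the whole stream before reading the CPU-side value.
host = counts.to(torch.device("cpu"), non_blocking=True)
torch.cuda.current_stream().synchronize()
print(host.item())

# Blocking-copy pattern (what the new code does for output_splits): the host
# waits for this transfer (and the work it depends on) to finish, so no
# separate stream-wide synchronize is needed before reading the value.
host = counts.to(torch.device("cpu"), non_blocking=False)
print(host.item())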