@@ -11,7 +11,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
+from torch.distributed._functional_collectives import (
+    all_to_all_single,
+    all_to_all_single_autograd,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -146,26 +149,26 @@ def _token_dispatch(self, mod, inputs, device_mesh):
 
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
-            num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
-                num_tokens_per_expert.shape[0]
-            )
-            dist.all_to_all_single(
-                num_tokens_per_expert_group,
+            num_tokens_per_expert_group = all_to_all_single(
                 num_tokens_per_expert,
+                None,
+                None,
                 group=device_mesh.get_group(),
             )
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )
             input_splits = (
                 num_tokens_per_expert.view(ep_size, -1)
                 .sum(dim=1)
-                .to(torch.device("cpu"), non_blocking=False)
+                .to(torch.device("cpu"), non_blocking=True)
             )
             output_splits = (
                 num_tokens_per_expert_group.view(ep_size, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=False)
             )
             # NOTE: this would incur a device-to-host sync
-            # torch.cuda.current_stream().synchronize()
             self.input_splits = input_splits.tolist()
             self.output_splits = output_splits.tolist()
 
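For context, the second hunk swaps the in-place c10d call (which needs a pre-allocated output buffer) for the functional collective, whose result must be explicitly waited on before it is read. A minimal sketch of that pattern is below; the helper name exchange_token_counts and the assumption of an already-initialized process group are illustrative, not part of this change.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single


def exchange_token_counts(
    num_tokens_per_expert: torch.Tensor, group: dist.ProcessGroup
) -> torch.Tensor:
    # Functional collective: returns a new (async) tensor instead of filling a
    # caller-allocated buffer the way dist.all_to_all_single does.
    # None/None for the output/input split sizes requests an even split across ranks.
    out = all_to_all_single(num_tokens_per_expert, None, None, group=group)
    # Explicitly wait on the async result so its values are valid before they
    # are viewed, summed, or copied to the host.
    return torch.ops._c10d_functional.wait_tensor(out)

In the updated _token_dispatch, the wait happens right after the collective, so the subsequent .view(ep_size, -1).sum(dim=1) and device-to-host copies operate on a completed result; the NOTE in the hunk points at the device-to-host sync that the final .tolist() calls still incur.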