Commit 69aacc9

Use functional collective
1 parent: bb33495

2 files changed: +12 −11 lines


torchtitan/distributed/activation_checkpoint.py

Lines changed: 1 addition & 3 deletions
@@ -25,6 +25,7 @@
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
     torch.ops._c10d_functional.all_to_all_single.default,
+    torch.ops._c10d_functional.wait_tensor.default,
     # for low precision training, it's useful to always save
     # the result of max, since the absolute maximum is
     # used to compute the scaling factor for quantization.
@@ -85,15 +86,12 @@ def _apply_ac_to_transformer_block(
 
     def _get_custom_policy(meta):
         def _custom_policy(ctx, func, *args, **kwargs):
-            # print("custom policy called", func)
             if (func == torch.ops.aten._to_copy.default
                 and "cuda" in str(args[0].device)
                 and "device" in kwargs
                 and str(kwargs["device"]) == "cpu"
             ):
                 return CheckpointPolicy.MUST_SAVE
-            # print("to_copy", args[0].device, kwargs)
-            # print("to_copy", args[0].device, kwargs)
             mode = "recompute" if ctx.is_recompute else "forward"
             mm_count_key = f"{mode}_mm_count"
             if func == torch.ops.aten.mm.default:
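Note: the op list above feeds PyTorch's selective activation checkpointing (SAC): ops on the list have their forward outputs saved rather than recomputed in backward, presumably so the waited result of the functional all-to-all is not re-issued during recompute. Below is a minimal, hedged sketch of how such a save list typically drives a policy through torch.utils.checkpoint; the trimmed op set and the helper names _policy and run_checkpointed are illustrative, not the torchtitan code.

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative op set; the real list lives in activation_checkpoint.py above.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops._c10d_functional.all_to_all_single.default,
    torch.ops._c10d_functional.wait_tensor.default,  # op added by this commit
}


def _policy(ctx, op, *args, **kwargs):
    # Save outputs of listed ops in the forward pass; let everything else
    # be recomputed during backward.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def run_checkpointed(block, *inputs):
    # use_reentrant=False is required for selective checkpointing.
    return checkpoint(
        block,
        *inputs,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )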

torchtitan/distributed/expert_parallel.py

Lines changed: 11 additions & 8 deletions
@@ -11,7 +11,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
+from torch.distributed._functional_collectives import (
+    all_to_all_single,
+    all_to_all_single_autograd,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -146,26 +149,26 @@ def _token_dispatch(self, mod, inputs, device_mesh):
 
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
-            num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
-                num_tokens_per_expert.shape[0]
-            )
-            dist.all_to_all_single(
-                num_tokens_per_expert_group,
+            num_tokens_per_expert_group = all_to_all_single(
                 num_tokens_per_expert,
+                None,
+                None,
                 group=device_mesh.get_group(),
             )
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )
             input_splits = (
                 num_tokens_per_expert.view(ep_size, -1)
                 .sum(dim=1)
-                .to(torch.device("cpu"), non_blocking=False)
+                .to(torch.device("cpu"), non_blocking=True)
             )
             output_splits = (
                 num_tokens_per_expert_group.view(ep_size, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=False)
             )
             # NOTE: this would incur a device-to-host sync
-            # torch.cuda.current_stream().synchronize()
             self.input_splits = input_splits.tolist()
             self.output_splits = output_splits.tolist()
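The core of the change is swapping the in-place dist.all_to_all_single, which writes into a preallocated output buffer, for the functional all_to_all_single, which returns a new tensor that is made safe to read with an explicit wait_tensor. The input-splits host copy is also switched to non_blocking=True, presumably because the later .tolist() already forces the device-to-host sync noted in the comment. Below is a minimal standalone sketch of the before/after pattern; the helper name exchange_token_counts is illustrative, and the process group is passed in rather than taken from device_mesh.get_group() as in the commit.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single


def exchange_token_counts(
    num_tokens_per_expert: torch.Tensor, group: dist.ProcessGroup
) -> torch.Tensor:
    # Before: in-place c10d collective writing into a preallocated buffer.
    #   out = num_tokens_per_expert.new_empty(num_tokens_per_expert.shape[0])
    #   dist.all_to_all_single(out, num_tokens_per_expert, group=group)

    # After: functional collective that returns a fresh tensor; passing None
    # for the output/input split sizes means an even split across ranks.
    out = all_to_all_single(num_tokens_per_expert, None, None, group=group)
    # The functional op is issued asynchronously; wait_tensor ensures the
    # collective has completed before the result is consumed (and is now on
    # the activation-checkpoint save list above).
    return torch.ops._c10d_functional.wait_tensor(out)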
