Skip to content

Commit 0809e9a

Browse files
committed
Save _to_copy and a2a in selective AC policy
1 parent ea4989e commit 0809e9a

File tree

4 files changed

+20
-8
lines changed

4 files changed

+20
-8
lines changed

torchtitan/distributed/activation_checkpoint.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def _apply_ac_to_transformer_block(
 
     def _get_custom_policy(meta):
         def _custom_policy(ctx, func, *args, **kwargs):
+            if (
+                func == torch.ops.aten._to_copy.default
+                and "cuda" in str(args[0].device)
+                and "device" in kwargs
+                and str(kwargs["device"]) == "cpu"
+            ):
+                return CheckpointPolicy.MUST_SAVE
             mode = "recompute" if ctx.is_recompute else "forward"
             mm_count_key = f"{mode}_mm_count"
             if func == torch.ops.aten.mm.default:

torchtitan/distributed/expert_parallel.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
+from torch.distributed._functional_collectives import (
+    all_to_all_single,
+    all_to_all_single_autograd,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -90,14 +93,15 @@ def _token_dispatch(self, mod, inputs, device_mesh):
 
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
-            num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
-                num_tokens_per_expert.shape[0]
-            )
-            dist.all_to_all_single(
-                num_tokens_per_expert_group,
+            num_tokens_per_expert_group = all_to_all_single(
                 num_tokens_per_expert,
+                None,
+                None,
                 group=device_mesh.get_group(),
             )
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )
             input_splits = (
                 num_tokens_per_expert.view(ep_size, -1)
                 .sum(dim=1)
@@ -106,10 +110,9 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         output_splits = (
             num_tokens_per_expert_group.view(ep_size, -1)
             .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
+            .to(torch.device("cpu"), non_blocking=False)
         )
         # NOTE: this would incur a device-to-host sync
-        torch.cuda.current_stream().synchronize()
         self.input_splits = input_splits.tolist()
         self.output_splits = output_splits.tolist()
 

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
         torch.ops.aten._scaled_dot_product_efficient_attention.default,
         torch.ops.aten._scaled_dot_product_flash_attention.default,
         torch.ops._c10d_functional.reduce_scatter_tensor.default,
+        torch.ops._c10d_functional.all_to_all_single.default,
         # for low precision training, it's useful to always save
         # the result of max, since the absolute maximum is
         # used to compute the scaling factor for quantization.

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
         torch.ops.aten._scaled_dot_product_efficient_attention.default,
         torch.ops.aten._scaled_dot_product_flash_attention.default,
         torch.ops._c10d_functional.reduce_scatter_tensor.default,
+        torch.ops._c10d_functional.all_to_all_single.default,
         # for low precision training, it's useful to always save
         # the result of max, since the absolute maximum is
         # used to compute the scaling factor for quantization.

0 commit comments

Comments (0)