
Commit ef75299

Workaround AC HOP mutation issue when tracing token dispatch
1 parent 8659543 commit ef75299
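
The change below removes the `input_splits` / `output_splits` / `input_shape` / `permuted_indices` attributes from `ExpertParallel` and threads those values through the return tuples of `_token_dispatch`, `GroupedExperts.forward`, and `_token_combine` instead. Mutating module attributes between the dispatch and combine hooks is what the activation checkpointing (AC) higher-order op (HOP) cannot handle while tracing; keeping the hooks purely functional sidesteps the issue. A minimal sketch of the pattern, using hypothetical toy classes rather than the torchtitan API:

import torch


class StatefulDispatch:
    """Problematic: dispatch stashes state on self, combine reads it back.

    The attribute write is a mutation that HOP tracing cannot capture
    in the region between the two hooks.
    """

    def dispatch(self, x):
        self.saved_shape = x.shape  # hidden state, mutated per call
        return x.reshape(-1)

    def combine(self, y):
        return y.reshape(self.saved_shape)


class FunctionalDispatch:
    """Trace-friendly: the state travels with the outputs instead."""

    def dispatch(self, x):
        return x.reshape(-1), x.shape  # state is returned, not stored

    def combine(self, y, saved_shape):
        return y.reshape(saved_shape)


# Round-trip a tensor through the functional pair: the shape survives
# without any attribute writes on the module.
x = torch.randn(4, 8)
flat, shape = FunctionalDispatch().dispatch(x)
assert torch.equal(FunctionalDispatch().combine(flat, shape), x)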

File tree

3 files changed, +37 -23 lines

torchtitan/distributed/expert_parallel.py

Lines changed: 27 additions & 19 deletions

@@ -67,10 +67,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 class ExpertParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
-        self.input_splits = None
-        self.output_splits = None
-        self.input_shape = None
-        self.permuted_indices = None
 
     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
@@ -103,14 +99,14 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             .sum(dim=1)
             .to(torch.device("cpu"), non_blocking=False)
         )
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        input_splits = input_splits.tolist()
+        output_splits = output_splits.tolist()
 
         # perform all-to-all
         routed_input = all_to_all_single_autograd(
             routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
             device_mesh.get_group(),
         )
 
@@ -127,15 +123,22 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # of GroupedExperts, as it does not need padding.
 
         (
-            self.input_shape,
+            input_shape,
             routed_input,
-            self.permuted_indices,
+            permuted_indices,
             num_tokens_per_expert_group,
         ) = _permute(
             routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
         )
 
-        return routed_input, num_tokens_per_expert_group
+        return (
+            routed_input,
+            num_tokens_per_expert_group,
+            input_shape,
+            permuted_indices,
+            input_splits,
+            output_splits,
+        )
 
     @staticmethod
     def _partition_fn(name, mod, device_mesh):
@@ -145,15 +148,20 @@ def _partition_fn(name, mod, device_mesh):
         mod.register_parameter(name, dist_param)
 
     # performing all-to-all combine on the output
-    def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = _unpermute(
-            routed_output, self.input_shape, self.permuted_indices
-        )
+    def _token_combine(self, mod, mod_outputs, device_mesh):
+        (
+            routed_output,
+            input_shape,
+            permuted_indices,
+            input_splits,
+            output_splits,
+        ) = mod_outputs
+        routed_output = _unpermute(routed_output, input_shape, permuted_indices)
 
         routed_output = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            input_splits,
+            output_splits,
             device_mesh.get_group(),
         )
         return routed_output
@@ -204,9 +212,9 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):
             nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])),
         ) # Column-wise sharding
 
-    def _token_combine(self, mod, routed_output, device_mesh):
+    def _token_combine(self, mod, mod_outputs, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_combine(mod, routed_output, device_mesh["ep"])
+        return super()._token_combine(mod, mod_outputs, device_mesh["ep"])
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 2 additions & 2 deletions

@@ -39,7 +39,7 @@ local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0 # grad norm clipping
 steps = 1000
-dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
+dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
 data_parallel_replicate_degree = 1
@@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
-enable=true
+enable = true
 components = ["loss"] # ["model", "loss"]
 
 [quantize.linear.float8]

torchtitan/models/moe/moe.py

Lines changed: 8 additions & 2 deletions

@@ -143,6 +143,10 @@ def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
+        input_shape,
+        permuted_indices,
+        input_splits,
+        output_splits,
     ) -> torch.Tensor:
         if isinstance(self.w1, DTensor):
             # Convert parameters from DTensors to plain Tensors, to work with
@@ -166,9 +170,11 @@ def forward(
                 run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm)
             else:
                 run_experts_fn = _run_experts_grouped_mm
-            return run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)
+            out = run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)
         else:
-            return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
+            out = _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
+
+        return (out, input_shape, permuted_indices, input_splits, output_splits)
 
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
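
Taken together, the three files keep the dispatch metadata alive across the expert computation without touching module state: `_token_dispatch` returns `input_shape`, `permuted_indices`, `input_splits`, and `output_splits` alongside the routed tokens, `GroupedExperts.forward` passes them through unchanged with its output, and `_token_combine` unpacks the tuple on the other side.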
