Skip to content

Commit 750868c

Browse files
authored
Merge pull request #60 from danielhua23/a2a_each_rank_seed
[enhance] make the seed and the MoE constant different for each rank, for debugging
2 parents fc547c6 + ecd9b59 commit 750868c

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

problems/amd_distributed/all2all/reference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def generate_input(
239239
):
240240
device = torch.device(f"cuda:{rank}")
241241
gen = torch.Generator(device=device)
242-
gen.manual_seed(seed)
242+
gen.manual_seed(seed + rank)
243243

244244
cfg = MoEConfig(
245245
num_experts=num_experts,
@@ -259,7 +259,7 @@ def ref_kernel(data: input_t) -> output_t:
259259
ata = PyTorchAllToAll(cfg, rank, world_size)
260260

261261
expert_num, expert_x, expert_meta = ata.dispatch(rank_data.x, rank_data.indices)
262-
expert_y = expert_x.to(cfg.out_dtype) * 2
262+
expert_y = expert_x.to(cfg.out_dtype) * (1 + rank)
263263
y = torch.zeros(
264264
cfg.max_num_tokens,
265265
cfg.hidden_dim,

problems/amd_distributed/all2all/submission.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def custom_kernel(data: input_t) -> output_t:
193193
ata = PyTorchAllToAll(cfg, rank, world_size)
194194

195195
expert_num, expert_x, expert_meta = ata.dispatch(rank_data.x, rank_data.indices)
196-
expert_y = expert_x.to(cfg.out_dtype) * 2
196+
expert_y = expert_x.to(cfg.out_dtype) * (1 + rank)
197197
y = torch.zeros(
198198
cfg.max_num_tokens,
199199
cfg.hidden_dim,

0 commit comments

Comments (0)