Skip to content

Commit 75bd4ab

Browse files
committed
Feat: works
1 parent 10015ea commit 75bd4ab

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

problems/amd_distributed.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,8 @@ problems:
1010
deadline: "2025-10-14"
1111
gpus:
1212
- MI300x8
13+
- directory: amd_distributed/gemm-rs
14+
name: amd-gemm-rs
15+
deadline: "2025-10-14"
16+
gpus:
17+
- MI300x8

problems/amd_distributed/gemm-rs/reference.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44

55

6-
def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t:
6+
def generate_input(rank: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t:
77
"""
88
Generate random input and weights for the Gemm-ReduceScatter operation.
99
@@ -14,21 +14,22 @@ def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, has_bias:
1414
bias: Optional[torch.Tensor], # [N] or None
1515
)
1616
"""
17-
gen = torch.Generator(device='cuda')
18-
gen.manual_seed(seed + RANK)
17+
device = torch.device(f'cuda:{rank}')
18+
gen = torch.Generator(device=device)
19+
gen.manual_seed(seed + rank)
1920

2021
assert m % world_size == 0, "m must be divisible by world_size"
2122
assert k % world_size == 0, "k must be divisible by world_size"
2223
local_k = k // world_size
2324

2425
# Generate random inputs and weights
25-
input = (torch.rand((m, local_k), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01
26-
weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01
26+
input = (torch.rand((m, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01
27+
weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01
2728

2829
bias = None
2930
if has_bias:
3031
gen.manual_seed(seed)
31-
bias = (torch.rand((n,), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01
32+
bias = (torch.rand((n,), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01
3233

3334
return (input, weight, bias)
3435

@@ -60,4 +61,12 @@ def ref_kernel(data: input_t) -> output_t:
6061
return rs_output
6162

6263

63-
check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2)
64+
def check_implementation(data: input_t, output: output_t):
    """
    Check a submitted output against the result of ref_kernel on the same data.

    Args:
        data: Input tuple, passed unchanged to ref_kernel to build the
            expected result.
        output: Tensor produced by the implementation under test.

    Returns:
        A (passed, message) tuple. passed is True only when output lives on
        the same device as the expected tensor, has the same shape, and
        matches it within rtol=1e-2 / atol=1e-2. message is empty on success
        and describes the first failed check otherwise.
    """
    expected = ref_kernel(data)

    if output.device != expected.device:
        return False, f"Output device mismatch: {output.device} != {expected.device}"

    # torch.allclose broadcasts its operands: without this guard a wrongly
    # shaped but broadcast-compatible output could pass, and an incompatible
    # one would raise instead of being reported as a failed check.
    if output.shape != expected.shape:
        return False, f"Output shape mismatch: {tuple(output.shape)} != {tuple(expected.shape)}"

    if not torch.allclose(output, expected, rtol=1e-2, atol=1e-2):
        return False, f"Output values mismatch, {output} != {expected}"

    return True, ""

problems/amd_distributed/gemm-rs/task.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ files:
88
- {"name": "eval.py", "source": "../eval.py"}
99

1010
lang: "py"
11+
multi_gpu: true
1112

1213
description: |
1314
Implement a Gemm-ReduceScatter kernel on a single MI300X node.

0 commit comments

Comments
 (0)