Skip to content

Commit 55a243d

Browse files
authored
Merge pull request #58 from wangxunx/gemm-reducescatter
add reference impl of gemm-reducescatter
2 parents 2963e52 + 0d0ca84 commit 55a243d

File tree

5 files changed

+204
-0
lines changed

5 files changed

+204
-0
lines changed

problems/amd_distributed.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,8 @@ problems:
1010
deadline: "2025-10-14"
1111
gpus:
1212
- MI300x8
13+
- directory: amd_distributed/gemm-rs
14+
name: amd-gemm-rs
15+
deadline: "2025-10-14"
16+
gpus:
17+
- MI300x8
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from task import input_t, output_t
2+
import torch
3+
4+
5+
def generate_input(rank: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t:
6+
"""
7+
Generate random input and weights for the Gemm-ReduceScatter operation.
8+
9+
Returns:
10+
Tuple of (
11+
input: torch.Tensor, # [M, local_K]
12+
weight: torch.Tensor, # [N, local_K]
13+
bias: Optional[torch.Tensor], # [N] or None
14+
)
15+
"""
16+
device = torch.device(f'cuda:{rank}')
17+
gen = torch.Generator(device=device)
18+
gen.manual_seed(seed + rank)
19+
20+
assert m % world_size == 0, "m must be divisible by world_size"
21+
assert k % world_size == 0, "k must be divisible by world_size"
22+
local_k = k // world_size
23+
24+
# Generate random inputs and weights
25+
input = (torch.rand((m, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01
26+
weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01
27+
28+
bias = None
29+
if has_bias:
30+
gen.manual_seed(seed)
31+
bias = (torch.rand((n,), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01
32+
33+
return (input, weight, bias)
34+
35+
36+
def ref_kernel(data: input_t) -> output_t:
37+
"""
38+
Reference kernel for Gemm-ReduceScatter operation.
39+
40+
Args:
41+
data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor])
42+
- input: Local input tensor of shape [M, local_K].
43+
- weight: Weight tensor of shape [N, local_K].
44+
- bias: Optional bias tensor of shape [N] or None.
45+
Returns:
46+
Tuple containing:
47+
- output: Resulting tensor of shape [M // world_size, N].
48+
"""
49+
input, weight, bias = data
50+
M, local_K = input.shape
51+
N = weight.shape[0]
52+
world_size = torch.distributed.get_world_size()
53+
# matmul
54+
output = torch.matmul(input, weight.T)
55+
if bias is not None:
56+
output = output + bias
57+
# reduce scatter
58+
rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device)
59+
torch.distributed.reduce_scatter_tensor(rs_output, output)
60+
return rs_output
61+
62+
63+
def check_implementation(data: input_t, output: output_t):
64+
expected = ref_kernel(data)
65+
if output.device != expected.device:
66+
return False, f"Output device mismatch: {output.device} != {expected.device}"
67+
res = torch.allclose(output, expected, rtol=1e-2, atol=1e-2)
68+
if not res:
69+
return False, f"Output values mismatch, {output} != {expected}"
70+
71+
return True, ""
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from task import input_t, output_t
2+
import torch
3+
4+
5+
def custom_kernel(data: input_t) -> output_t:
6+
"""
7+
Reference kernel for Gemm-ReduceScatter operation.
8+
9+
Args:
10+
data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor])
11+
- input: Local input tensor of shape [M, local_K].
12+
- weight: Weight tensor of shape [N, local_K].
13+
- bias: Optional bias tensor of shape [N] or None.
14+
Returns:
15+
Tuple containing:
16+
- output: Resulting tensor of shape [M // world_size, N].
17+
"""
18+
input, weight, bias = data
19+
M, local_K = input.shape
20+
N = weight.shape[0]
21+
world_size = torch.distributed.get_world_size()
22+
# matmul
23+
output = torch.matmul(input, weight.T)
24+
if bias is not None:
25+
output = output + bias
26+
# reduce scatter
27+
rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device)
28+
torch.distributed.reduce_scatter_tensor(rs_output, output)
29+
return rs_output
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import TypedDict, TypeVar, Tuple, Optional
2+
import torch
3+
4+
input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]])
5+
output_t = TypeVar("output_t", bound=torch.Tensor)
6+
7+
8+
class TestSpec(TypedDict):
9+
world_size: int
10+
m: int
11+
n: int
12+
k: int
13+
has_bias: bool
14+
seed: int
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# name: gemm-rs
2+
3+
files:
4+
- {"name": "submission.py", "source": "@SUBMISSION@"}
5+
- {"name": "task.py", "source": "task.py"}
6+
- {"name": "utils.py", "source": "../utils.py"}
7+
- {"name": "reference.py", "source": "reference.py"}
8+
- {"name": "eval.py", "source": "../eval.py"}
9+
10+
lang: "py"
11+
multi_gpu: true
12+
13+
description: |
14+
Implement a Gemm-ReduceScatter kernel on a single MI300X node.
15+
16+
Gemm-ReduceScatter is a technique that combines the ReduceScatter
17+
communication pattern with General Matrix Multiplication (GEMM) to optimize
18+
the performance of transformer models on GPUs. It is particularly useful for
19+
handling large models that exceed the memory capacity of a single GPU by
20+
distributing the model across multiple GPUs and efficiently scattering the
21+
results of matrix multiplications.
22+
23+
Your task:
24+
- Implement the Gemm-RS kernel to perform matrix multiplications in a
25+
distributed manner, leveraging the ReduceScatter operation to distribute
26+
data across multiple GPUs.
27+
- Ensure that the implementation is optimized for the MI300X architecture,
28+
taking advantage of its specific hardware features for maximum performance.
29+
30+
Input:
31+
- `data`: Tuple of (input: torch.Tensor, weights: torch.Tensor,
32+
bias: Optional, None or torch.Tensor)
33+
- input: Local input tensor of shape [M, local_K].
34+
- weight: Weight tensor of shape [N, local_K].
35+
- bias: bias tensor of shape [N] or None.
36+
37+
Output:
38+
- Tuple containing:
39+
- output: Resulting tensor of shape [M // world_size, N]
40+
41+
The ranking criteria is the geometric mean of the benchmark results.
42+
43+
For the grand price, your kernel will be evaluated against the speed of light
44+
analysis and AMD implementations, the solution closest to the speed of light
45+
and AMD implementations will be awarded the grand price.
46+
```
47+
The speed of light analysis is:
48+
m n k has_bias time[us]
49+
64 7168 18432 False 6.46
50+
512 4096 12288 True 8.19
51+
2048 2880 2880 True 23.04
52+
4096 4096 4096 False 65.54
53+
8192 4096 14336 True 131.07
54+
8192 8192 29568 False 379.43
55+
```
56+
config:
57+
main: "eval.py"
58+
59+
templates:
60+
Python: "submission.py"
61+
62+
ranking_by: "geom"
63+
ranked_timeout: 360 # just in case
64+
65+
tests:
66+
- {"world_size": 8, "m": 64, "n": 2880, "k": 2880, "has_bias": True, "seed": 2035}
67+
- {"world_size": 8, "m": 64, "n": 3584, "k": 14336, "has_bias": True, "seed": 13}
68+
- {"world_size": 8, "m": 512, "n": 3584, "k": 14336, "has_bias": True, "seed": 4297}
69+
- {"world_size": 8, "m": 512, "n": 4608, "k": 36864, "has_bias": False, "seed": 1597}
70+
- {"world_size": 8, "m": 2048, "n": 4096, "k": 7168, "has_bias": False, "seed": 716}
71+
- {"world_size": 8, "m": 2048, "n": 8192, "k": 30720, "has_bias": False, "seed": 20201}
72+
- {"world_size": 8, "m": 4096, "n": 2880, "k": 2880, "has_bias": True, "seed": 136}
73+
- {"world_size": 8, "m": 4096, "n": 8192, "k": 2048, "has_bias": True, "seed": 138}
74+
- {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "has_bias": True, "seed": 748}
75+
- {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "has_bias": True, "seed": 4422}
76+
- {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "has_bias": False, "seed": 1536}
77+
78+
79+
benchmarks:
80+
- {"world_size": 8, "m": 64, "n": 7168, "k": 18432, "has_bias": False, "seed": 1234}
81+
- {"world_size": 8, "m": 512, "n": 4096, "k": 12288, "has_bias": True, "seed": 663}
82+
- {"world_size": 8, "m": 2048, "n": 2880, "k": 2880, "has_bias": True, "seed": 166}
83+
- {"world_size": 8, "m": 4096, "n": 4096, "k": 4096, "has_bias": False, "seed": 1371}
84+
- {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "has_bias": True, "seed": 7168}
85+
- {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "has_bias": False, "seed": 42}

0 commit comments

Comments
 (0)