Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions problems/amd_distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ problems:
deadline: "2025-10-14"
gpus:
- MI300x8
- directory: amd_distributed/gemm-rs
name: amd-gemm-rs
deadline: "2025-10-14"
gpus:
- MI300x8
71 changes: 71 additions & 0 deletions problems/amd_distributed/gemm-rs/reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from task import input_t, output_t
import torch


def generate_input(rank: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t:
    """
    Build one rank's random operands for the Gemm-ReduceScatter problem.

    The K dimension is split evenly over the ranks, so every rank only holds a
    slice of width local_K = k // world_size of both the input and the weight.
    Input and weight are drawn from a generator seeded with ``seed + rank``;
    the bias is drawn after re-seeding the same generator with ``seed`` alone
    (no rank offset).

    Args:
        rank: Index of the calling rank; also selects the device ``cuda:{rank}``.
        world_size: Number of participating ranks; must divide both m and k.
        m: Number of rows of the input.
        n: Number of rows of the weight (and length of the bias).
        k: Global reduction dimension, before it is split across ranks.
        has_bias: Whether a bias vector is generated.
        seed: Base seed for the random generator.

    Returns:
        Tuple of (
            input: torch.Tensor,           # [M, local_K]
            weight: torch.Tensor,          # [N, local_K]
            bias: Optional[torch.Tensor],  # [N] or None
        )
    """
    device = torch.device(f'cuda:{rank}')
    rng = torch.Generator(device=device)
    rng.manual_seed(seed + rank)

    assert m % world_size == 0, "m must be divisible by world_size"
    assert k % world_size == 0, "k must be divisible by world_size"
    k_per_rank = k // world_size

    def _small_uniform(shape):
        # torch.rand output is shifted/scaled to be centred on zero with a
        # magnitude of at most about 0.01.
        sample = torch.rand(shape, dtype=torch.bfloat16, device=device, generator=rng)
        return (sample * 2 - 1) * 0.01

    # Draw order matters for reproducibility: activations first, then weights.
    activations = _small_uniform((m, k_per_rank))
    weight = _small_uniform((n, k_per_rank))

    if not has_bias:
        return (activations, weight, None)

    # Re-seed without the rank offset before sampling the bias.
    rng.manual_seed(seed)
    bias = _small_uniform((n,))
    return (activations, weight, bias)


def ref_kernel(data: input_t) -> output_t:
    """
    Reference Gemm-ReduceScatter: local GEMM, optional bias, then reduce-scatter.

    Each rank multiplies its local input by the transpose of its local weight,
    adds the bias (when present) to that partial product, and then calls
    ``torch.distributed.reduce_scatter_tensor`` with its default reduce op so
    that every rank receives one [M // world_size, N] piece of the result.

    Args:
        data: Tuple of (input, weight, bias)
            - input: Local input tensor of shape [M, local_K].
            - weight: Weight tensor of shape [N, local_K].
            - bias: Optional bias tensor of shape [N] or None.
    Returns:
        Tensor of shape [M // world_size, N] holding this rank's piece.
    """
    local_input, weight, bias = data
    # Unpacking into two names also enforces that the input is 2-D.
    rows, _ = local_input.shape
    cols = weight.shape[0]
    num_ranks = torch.distributed.get_world_size()

    # Partial product over this rank's slice of K.
    partial = local_input @ weight.T
    if bias is not None:
        partial = partial + bias

    # Destination buffer for this rank's piece of the reduced matrix.
    shard = torch.empty((rows // num_ranks, cols), dtype=partial.dtype, device=local_input.device)
    torch.distributed.reduce_scatter_tensor(shard, partial)
    return shard


def check_implementation(data: input_t, output: output_t):
    """
    Compare a submission's output against the reference kernel.

    Args:
        data: The input tuple that was handed to the submission.
        output: Tensor produced by the submission.
    Returns:
        Tuple of (passed, message). ``message`` is an empty string on success
        and describes the first detected mismatch otherwise.
    """
    # ref_kernel performs a reduce-scatter, so it is always called before any
    # early return below.
    expected = ref_kernel(data)
    if output.device != expected.device:
        return False, f"Output device mismatch: {output.device} != {expected.device}"
    # torch.allclose compares with broadcasting: a wrongly shaped output could
    # either slip through or raise. Report it as an ordinary failure instead.
    if output.shape != expected.shape:
        return False, f"Output shape mismatch: {tuple(output.shape)} != {tuple(expected.shape)}"
    res = torch.allclose(output, expected, rtol=1e-2, atol=1e-2)
    if not res:
        return False, f"Output values mismatch, {output} != {expected}"

    return True, ""
29 changes: 29 additions & 0 deletions problems/amd_distributed/gemm-rs/submission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from task import input_t, output_t
import torch


def custom_kernel(data: input_t) -> output_t:
    """
    Baseline Gemm-ReduceScatter submission built from stock PyTorch ops.

    Computes input @ weight.T on the local shard, adds the bias when one is
    given, and reduce-scatters the result across the process group.

    Args:
        data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor])
            - input: Local input tensor of shape [M, local_K].
            - weight: Weight tensor of shape [N, local_K].
            - bias: Optional bias tensor of shape [N] or None.
    Returns:
        output: Resulting tensor of shape [M // world_size, N].
    """
    x, w, b = data
    # Two-name unpack keeps the requirement that the input is 2-D.
    m_total, _ = x.shape
    n_total = w.shape[0]
    group_size = torch.distributed.get_world_size()

    # Local GEMM over this rank's slice of K, with the optional bias folded in.
    local_result = torch.matmul(x, w.T)
    local_result = local_result if b is None else local_result + b

    # Each rank keeps M // world_size rows of the reduced result.
    out_shape = (m_total // group_size, n_total)
    scattered = torch.empty(out_shape, dtype=local_result.dtype, device=x.device)
    torch.distributed.reduce_scatter_tensor(scattered, local_result)
    return scattered
14 changes: 14 additions & 0 deletions problems/amd_distributed/gemm-rs/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import TypedDict, TypeVar, Tuple, Optional
import torch

# Kernel input: a 3-tuple of (tensor, tensor, optional tensor).
input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]])
# Kernel output: a single tensor.
output_t = TypeVar("output_t", bound=torch.Tensor)


class TestSpec(TypedDict):
    """Keys of a single test or benchmark case for the gemm-rs problem."""

    # Number of ranks taking part in the run.
    world_size: int

    # Problem dimensions.
    m: int
    n: int
    k: int

    # Whether a bias vector is part of the input.
    has_bias: bool

    # Seed for input generation.
    seed: int
85 changes: 85 additions & 0 deletions problems/amd_distributed/gemm-rs/task.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# name: gemm-rs

files:
- {"name": "submission.py", "source": "@SUBMISSION@"}
- {"name": "task.py", "source": "task.py"}
- {"name": "utils.py", "source": "../utils.py"}
- {"name": "reference.py", "source": "reference.py"}
- {"name": "eval.py", "source": "../eval.py"}

lang: "py"
multi_gpu: true

description: |
Implement a Gemm-ReduceScatter kernel on a single MI300X node.

Gemm-ReduceScatter is a technique that combines the ReduceScatter
communication pattern with General Matrix Multiplication (GEMM) to optimize
the performance of transformer models on GPUs. It is particularly useful for
handling large models that exceed the memory capacity of a single GPU by
distributing the model across multiple GPUs and efficiently scattering the
results of matrix multiplications.

Your task:
- Implement the Gemm-RS kernel to perform matrix multiplications in a
distributed manner, leveraging the ReduceScatter operation to distribute
data across multiple GPUs.
- Ensure that the implementation is optimized for the MI300X architecture,
taking advantage of its specific hardware features for maximum performance.

Input:
- `data`: Tuple of (input: torch.Tensor, weight: torch.Tensor,
bias: Optional[torch.Tensor])
- input: Local input tensor of shape [M, local_K].
- weight: Weight tensor of shape [N, local_K].
- bias: bias tensor of shape [N] or None.

Output:
- Tuple containing:
- output: Resulting tensor of shape [M // world_size, N]

The ranking criterion is the geometric mean of the benchmark results.

For the grand prize, your kernel will be evaluated against the speed of light
analysis and AMD implementations; the solution closest to the speed of light
and AMD implementations will be awarded the grand prize.
```
The speed of light analysis is:
m n k has_bias time[us]
64 7168 18432 False 6.46
512 4096 12288 True 8.19
2048 2880 2880 True 23.04
4096 4096 4096 False 65.54
8192 4096 14336 True 131.07
8192 8192 29568 False 379.43
```
config:
main: "eval.py"

templates:
Python: "submission.py"

ranking_by: "geom"
ranked_timeout: 360 # just in case

tests:
- {"world_size": 8, "m": 64, "n": 2880, "k": 2880, "has_bias": True, "seed": 2035}
- {"world_size": 8, "m": 64, "n": 3584, "k": 14336, "has_bias": True, "seed": 13}
- {"world_size": 8, "m": 512, "n": 3584, "k": 14336, "has_bias": True, "seed": 4297}
- {"world_size": 8, "m": 512, "n": 4608, "k": 36864, "has_bias": False, "seed": 1597}
- {"world_size": 8, "m": 2048, "n": 4096, "k": 7168, "has_bias": False, "seed": 716}
- {"world_size": 8, "m": 2048, "n": 8192, "k": 30720, "has_bias": False, "seed": 20201}
- {"world_size": 8, "m": 4096, "n": 2880, "k": 2880, "has_bias": True, "seed": 136}
- {"world_size": 8, "m": 4096, "n": 8192, "k": 2048, "has_bias": True, "seed": 138}
- {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "has_bias": True, "seed": 748}
- {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "has_bias": True, "seed": 4422}
- {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "has_bias": False, "seed": 1536}


benchmarks:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add various different shapes

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shapes definition updated.

- {"world_size": 8, "m": 64, "n": 7168, "k": 18432, "has_bias": False, "seed": 1234}
- {"world_size": 8, "m": 512, "n": 4096, "k": 12288, "has_bias": True, "seed": 663}
- {"world_size": 8, "m": 2048, "n": 2880, "k": 2880, "has_bias": True, "seed": 166}
- {"world_size": 8, "m": 4096, "n": 4096, "k": 4096, "has_bias": False, "seed": 1371}
- {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "has_bias": True, "seed": 7168}
- {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "has_bias": False, "seed": 42}