63 changes: 63 additions & 0 deletions problems/amd_distributed/gemm-rs/reference.py
@@ -0,0 +1,63 @@
from utils import make_match_reference
from task import input_t, output_t
import torch


def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, seed: int) -> input_t:
    """
    Generate random input and weights for the GEMM-ReduceScatter operation.

    Returns:
        Tuple of (
            input: torch.Tensor,           # [M, local_K]
            weight: torch.Tensor,          # [N, local_K]
            transposed_weight: bool,       # Whether the weight is transposed
            bias: Optional[torch.Tensor],  # [N] or None
        )
    """
    gen = torch.Generator(device='cuda')
    # Offset the base seed by RANK so every rank draws distinct local data.
    gen.manual_seed(seed + RANK)

    assert m % world_size == 0, "m must be divisible by world_size"
    assert k % world_size == 0, "k must be divisible by world_size"
    local_k = k // world_size

    # Uniform values scaled to [-0.01, 0.01] to keep float16 magnitudes small.
    input = (torch.rand((m, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01
    weight = (torch.rand((n, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01

    return (input, weight, False, None)


def ref_kernel(data: input_t) -> output_t:
    """
    Reference kernel for the Gemm-ReduceScatter operation.

    Args:
        data: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool,
              bias: Optional[torch.Tensor])
            - input: Local input tensor of shape [M, local_K].
            - weight: Weight tensor of shape [N, local_K], or [local_K, N] if
              transposed_weight is True.
            - transposed_weight: Whether the weight is transposed.
            - bias: Optional bias tensor of shape [N] or None.
    Returns:
        The local output tensor of shape [M // world_size, N].
    """
    input, weight, transposed_weight, bias = data
    M, local_K = input.shape
    if not transposed_weight:
        weight = weight.T  # [N, local_K] -> [local_K, N]
    N = weight.shape[1]
    world_size = torch.distributed.get_world_size()
    # Local partial GEMM: [M, local_K] @ [local_K, N] -> [M, N]
    output = torch.matmul(input, weight)
    if bias is not None:
        output = output + bias
    # Reduce-scatter: sum the partial results across ranks, leaving each rank
    # with its own [M // world_size, N] row block.
    rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device)
    torch.distributed.reduce_scatter_tensor(rs_output, output)
    return rs_output


check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2)
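
For orientation, here is a minimal sketch of how a kernel with this signature could be driven end to end. It is a hypothetical harness: eval.py (which actually runs submissions) is not shown in this diff, and the torchrun launch and NCCL/RCCL setup here are assumptions.

    # Hypothetical driver, e.g. `torchrun --nproc_per_node=8 driver.py`.
    import torch
    import torch.distributed as dist
    from reference import generate_input, ref_kernel

    def main():
        dist.init_process_group(backend="nccl")  # maps to RCCL on AMD GPUs
        rank, world_size = dist.get_rank(), dist.get_world_size()
        torch.cuda.set_device(rank)
        data = generate_input(RANK=rank, world_size=world_size,
                              m=8192, n=3584, k=14336, seed=42)
        out = ref_kernel(data)
        assert out.shape == (8192 // world_size, 3584)
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()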
33 changes: 33 additions & 0 deletions problems/amd_distributed/gemm-rs/submission.py
@@ -0,0 +1,33 @@
from task import input_t, output_t
import torch


def custom_kernel(data: input_t) -> output_t:
    """
    Custom kernel for the Gemm-ReduceScatter operation (this starter version
    matches the reference implementation).

    Args:
        data: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool,
              bias: Optional[torch.Tensor])
            - input: Local input tensor of shape [M, local_K].
            - weight: Weight tensor of shape [N, local_K], or [local_K, N] if
              transposed_weight is True.
            - transposed_weight: Whether the weight is transposed.
            - bias: Optional bias tensor of shape [N] or None.
    Returns:
        The local output tensor of shape [M // world_size, N].
    """
    input, weight, transposed_weight, bias = data
    M, local_K = input.shape
    if not transposed_weight:
        weight = weight.T  # [N, local_K] -> [local_K, N]
    N = weight.shape[1]
    world_size = torch.distributed.get_world_size()
    # Local partial GEMM: [M, local_K] @ [local_K, N] -> [M, N]
    output = torch.matmul(input, weight)
    if bias is not None:
        output = output + bias
    # Reduce-scatter the partial results across ranks.
    rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device)
    torch.distributed.reduce_scatter_tensor(rs_output, output)
    return rs_output
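
The starter submission above just repeats the reference computation. A common direction for optimizing it is to split the GEMM along N and issue each chunk's reduce-scatter asynchronously, so communication overlaps the next chunk's matmul. A minimal sketch of that idea (the function name and chunk count are illustrative, not part of the task files):

    import torch
    import torch.distributed as dist

    def overlapped_gemm_rs(input, weight_t, bias, num_chunks=4):
        # weight_t is the already-transposed [local_K, N] weight. Chunking along
        # N keeps each chunk's reduce-scatter (which splits rows along M) correct.
        world_size = dist.get_world_size()
        M, N = input.shape[0], weight_t.shape[1]
        step = N // num_chunks
        outs, works = [], []
        for j in range(0, N, step):
            part = torch.matmul(input, weight_t[:, j:j + step])  # [M, <=step]
            if bias is not None:
                part = part + bias[j:j + step]
            out = torch.empty((M // world_size, part.shape[1]),
                              dtype=part.dtype, device=part.device)
            # async_op=True returns a Work handle immediately, letting the next
            # chunk's matmul overlap with this chunk's communication.
            works.append(dist.reduce_scatter_tensor(out, part, async_op=True))
            outs.append(out)
        for w in works:
            w.wait()
        return torch.cat(outs, dim=1)  # [M // world_size, N]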
13 changes: 13 additions & 0 deletions problems/amd_distributed/gemm-rs/task.py
@@ -0,0 +1,13 @@
from typing import TypedDict, TypeVar, Tuple, Optional
import torch

# (input, weight, transposed_weight, bias), as produced by generate_input
input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, bool, Optional[torch.Tensor]])
# The local [M // world_size, N] result of the reduce-scatter
output_t = TypeVar("output_t", bound=torch.Tensor)


class TestSpec(TypedDict):
    world_size: int
    m: int
    n: int
    k: int
    seed: int
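
For example, the first test entry in task.yml corresponds to this TestSpec:

    spec: TestSpec = {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "seed": 42}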
61 changes: 61 additions & 0 deletions problems/amd_distributed/gemm-rs/task.yml
@@ -0,0 +1,61 @@
# name: gemm-rs

files:
- {"name": "submission.py", "source": "@SUBMISSION@"}
- {"name": "task.py", "source": "task.py"}
- {"name": "utils.py", "source": "../utils.py"}
- {"name": "reference.py", "source": "reference.py"}
- {"name": "eval.py", "source": "../eval.py"}

lang: "py"

description: |
  Implement a Gemm-ReduceScatter kernel for efficient transformer models
  on a single MI300X node.

  Gemm-ReduceScatter (GEMM-RS) combines General Matrix Multiplication (GEMM)
  with the ReduceScatter communication pattern to optimize the performance of
  transformer models on GPUs. It is particularly useful for models that exceed
  the memory capacity of a single GPU: the weights are sharded across the GPUs,
  each GPU computes a partial GEMM over its shard, and the partial results are
  reduce-scattered so that every rank ends up with its slice of the final output.

  Your task:
  - Implement the GEMM-RS kernel to perform the matrix multiplication in a
    distributed manner, using the ReduceScatter operation to combine the
    partial results across GPUs (see the shape walkthrough after this list).
  - Ensure that the implementation is optimized for the MI300X architecture,
    taking advantage of its specific hardware features for maximum performance.
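
  Per-rank shape walkthrough: the local GEMM computes
  [M, local_K] x [local_K, N] -> [M, N] partial results; the reduce-scatter
  then sums these partials across all world_size ranks and leaves each rank
  with its own [M // world_size, N] row block.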

  Input:
  - `data`: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool,
    bias: Optional[torch.Tensor], TP_GROUP: group object)
    - input: Local input tensor of shape [M, local_K].
    - weight: Weight tensor of shape [N, local_K], or [local_K, N] if
      transposed_weight is True.
    - transposed_weight: Whether the weight is transposed.
    - bias: Optional bias tensor of shape [N] or None.
    - TP_GROUP: Process group for tensor parallelism.

  Output:
  - output: Resulting tensor of shape [M // world_size, N].

config:
main: "eval.py"

templates:
Python: "submission.py"

ranking_by: "geom"

tests:
- {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "seed": 42}
- {"world_size": 8, "m": 8192, "n": 4096, "k": 12288, "seed": 6635}
- {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "seed": 4422}
- {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "seed": 1536}


benchmarks:
Contributor: add various different shapes

Author: Shapes definition updated.

- {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "seed": 7168}
- {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "seed": 1024}
- {"world_size": 8, "m": 8192, "n": 8192, "k": 30720, "seed": 2035}