From dfced0bc7bfe8631df4f77f1293023a7986aca5a Mon Sep 17 00:00:00 2001 From: Xun Wang Date: Tue, 2 Sep 2025 06:21:35 -0500 Subject: [PATCH 1/9] add reference impl of gemm-reducescatter --- problems/amd_distributed/gemm-rs/reference.py | 63 +++++++++++++++++++ .../amd_distributed/gemm-rs/submission.py | 33 ++++++++++ problems/amd_distributed/gemm-rs/task.py | 13 ++++ problems/amd_distributed/gemm-rs/task.yml | 61 ++++++++++++++++++ 4 files changed, 170 insertions(+) create mode 100644 problems/amd_distributed/gemm-rs/reference.py create mode 100644 problems/amd_distributed/gemm-rs/submission.py create mode 100644 problems/amd_distributed/gemm-rs/task.py create mode 100644 problems/amd_distributed/gemm-rs/task.yml diff --git a/problems/amd_distributed/gemm-rs/reference.py b/problems/amd_distributed/gemm-rs/reference.py new file mode 100644 index 0000000..5ff5904 --- /dev/null +++ b/problems/amd_distributed/gemm-rs/reference.py @@ -0,0 +1,63 @@ +from utils import make_match_reference +from task import input_t, output_t +import torch + + +def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, seed: int) -> input_t: + """ + Generate random input and weights for the AG-GEMM operation. 
+ + Returns: + Tuple of ( + input: torch.Tensor, # [M, local_K] + weight: torch.Tensor, # [N, local_K] + transposed_weight: bool, # Whether the weight is transposed + bias: Optional[torch.Tensor], # [N] or None + ) + """ + gen = torch.Generator(device='cuda') + gen.manual_seed(seed + RANK) + + assert m % world_size == 0, "m must be divisible by world_size" + assert k % world_size == 0, "k must be divisible by world_size" + local_k = k // world_size + + # Generate random inputs and weights + input = (torch.rand((m, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 + weight = (torch.rand((n, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 + + return (input, weight, False, None) + + +def ref_kernel(data: input_t) -> output_t: + """ + Reference kernel for Gemm-ReduceScatter operation. + + Args: + data: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool, + bias: Optional[torch.Tensor]) + - input: Local input tensor of shape [M, local_K]. + - weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. + - transposed_weight: Whether the weight is transposed. + - bias: Optional bias tensor of shape [N] or None. + Returns: + Tuple containing: + - output: Resulting tensor of shape [M // world_size, N]. 
+ """ + input, weight, transposed_weight, bias = data + M, local_K = input.shape + if not transposed_weight: + weight = weight.T + N = weight.shape[1] + world_size = torch.distributed.get_world_size() + # matmul + output = torch.matmul(input, weight) + if bias is not None: + output = output + bias + # reduce scatter + rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device) + torch.distributed.reduce_scatter_tensor(rs_output, output) + return rs_output + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2) diff --git a/problems/amd_distributed/gemm-rs/submission.py b/problems/amd_distributed/gemm-rs/submission.py new file mode 100644 index 0000000..dce77b4 --- /dev/null +++ b/problems/amd_distributed/gemm-rs/submission.py @@ -0,0 +1,33 @@ +from task import input_t, output_t +import torch + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference kernel for Gemm-ReduceScatter operation. + + Args: + data: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool, + bias: Optional[torch.Tensor]) + - input: Local input tensor of shape [M, local_K]. + - weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. + - transposed_weight: Whether the weight is transposed. + - bias: Optional bias tensor of shape [N] or None. + Returns: + Tuple containing: + - output: Resulting tensor of shape [M // world_size, N]. 
+ """ + input, weight, transposed_weight, bias = data + M, local_K = input.shape + if not transposed_weight: + weight = weight.T + N = weight.shape[1] + world_size = torch.distributed.get_world_size() + # matmul + output = torch.matmul(input, weight) + if bias is not None: + output = output + bias + # reduce scatter + rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device) + torch.distributed.reduce_scatter_tensor(rs_output, output) + return rs_output diff --git a/problems/amd_distributed/gemm-rs/task.py b/problems/amd_distributed/gemm-rs/task.py new file mode 100644 index 0000000..9be0c81 --- /dev/null +++ b/problems/amd_distributed/gemm-rs/task.py @@ -0,0 +1,13 @@ +from typing import TypedDict, TypeVar, Tuple, Dict +import torch + +input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict]) +output_t = TypeVar("output_t", bound=Tuple[torch.Tensor, Dict]) + + +class TestSpec(TypedDict): + world_size: int + m: int + n: int + k: int + seed: int \ No newline at end of file diff --git a/problems/amd_distributed/gemm-rs/task.yml b/problems/amd_distributed/gemm-rs/task.yml new file mode 100644 index 0000000..c8b6bb8 --- /dev/null +++ b/problems/amd_distributed/gemm-rs/task.yml @@ -0,0 +1,61 @@ +# name: gemm-rs + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + Implement a Gemm-ReduceScatter kernel for efficient transformer models + on a single MI300X device. + + ReduceScatter-Gemm (RS-Gemm) is a technique that combines the ReduceScatter + communication pattern with General Matrix Multiplication (GEMM) to optimize + the performance of transformer models on GPUs. 
It is particularly useful for + handling large models that exceed the memory capacity of a single GPU by + distributing the model across multiple GPUs and efficiently scattering the + results of matrix multiplications. + + Your task: + - Implement the Gemm-RS kernel to perform matrix multiplications in a + distributed manner, leveraging the ReduceScatter operation to distribute + data across multiple GPUs. + - Ensure that the implementation is optimized for the MI300X architecture, + taking advantage of its specific hardware features for maximum performance. + + Input: + - `data`: Tuple of (input: torch.Tensor, weights: torch.Tensor, transposed_weight: bool, + bias: Optional, None or torch.Tensor, TP_GROUP: group object) + - input: Local input tensor of shape [M, local_K]. + - weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. + - transposed_weight: Whether the weight is transposed. + - bias: bias tensor of shape [N] or None. + - TP_GROUP: Process group for tensor parallelism + + Output: + - Tuple containing: + - output: Resulting tensor of shape [M // world_size, N] + +config: + main: "eval.py" + +templates: + Python: "submission.py" + +ranking_by: "geom" + +tests: + - {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "seed": 42} + - {"world_size": 8, "m": 8192, "n": 4096, "k": 12288, "seed": 6635} + - {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "seed": 4422} + - {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "seed": 1536} + + +benchmarks: + - {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "seed": 7168} + - {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "seed": 1024} + - {"world_size": 8, "m": 8192, "n": 8192, "k": 30720, "seed": 2035} From 35eb4e12cc4990200091f9675cbf786b6536fd2d Mon Sep 17 00:00:00 2001 From: Xun Wang Date: Mon, 8 Sep 2025 00:55:04 -0500 Subject: [PATCH 2/9] typo fix --- problems/amd_distributed/gemm-rs/reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/problems/amd_distributed/gemm-rs/reference.py b/problems/amd_distributed/gemm-rs/reference.py index 5ff5904..52f4a3a 100644 --- a/problems/amd_distributed/gemm-rs/reference.py +++ b/problems/amd_distributed/gemm-rs/reference.py @@ -5,7 +5,7 @@ def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, seed: int) -> input_t: """ - Generate random input and weights for the AG-GEMM operation. + Generate random input and weights for the Gemm-ReduceScatter operation. Returns: Tuple of ( From 8e35bf1e9d1140380e41d48e53f854c95455a404 Mon Sep 17 00:00:00 2001 From: Xun Wang Date: Mon, 8 Sep 2025 01:55:56 -0500 Subject: [PATCH 3/9] remove transposed_weight bool variable --- problems/amd_distributed/gemm-rs/reference.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/problems/amd_distributed/gemm-rs/reference.py b/problems/amd_distributed/gemm-rs/reference.py index 52f4a3a..5a12156 100644 --- a/problems/amd_distributed/gemm-rs/reference.py +++ b/problems/amd_distributed/gemm-rs/reference.py @@ -11,7 +11,6 @@ def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, seed: int Tuple of ( input: torch.Tensor, # [M, local_K] weight: torch.Tensor, # [N, local_K] - transposed_weight: bool, # Whether the weight is transposed bias: Optional[torch.Tensor], # [N] or None ) """ @@ -26,7 +25,7 @@ def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, seed: int input = (torch.rand((m, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 weight = (torch.rand((n, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 - return (input, weight, False, None) + return (input, weight, None) def ref_kernel(data: input_t) -> output_t: @@ -34,24 +33,20 @@ def ref_kernel(data: input_t) -> output_t: Reference kernel for Gemm-ReduceScatter operation. 
Args: - data: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool, - bias: Optional[torch.Tensor]) + data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) - input: Local input tensor of shape [M, local_K]. - - weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. - - transposed_weight: Whether the weight is transposed. + - weight: Weight tensor of shape [N, local_K]. - bias: Optional bias tensor of shape [N] or None. Returns: Tuple containing: - output: Resulting tensor of shape [M // world_size, N]. """ - input, weight, transposed_weight, bias = data + input, weight, bias = data M, local_K = input.shape - if not transposed_weight: - weight = weight.T - N = weight.shape[1] + N = weight.shape[0] world_size = torch.distributed.get_world_size() # matmul - output = torch.matmul(input, weight) + output = torch.matmul(input, weight.T) if bias is not None: output = output + bias # reduce scatter From 11303bea37bcf061564a9a37e24e75bb66a96d00 Mon Sep 17 00:00:00 2001 From: Xun Wang Date: Mon, 8 Sep 2025 02:10:22 -0500 Subject: [PATCH 4/9] update custom_kernel and yaml doc --- problems/amd_distributed/gemm-rs/submission.py | 14 +++++--------- problems/amd_distributed/gemm-rs/task.yml | 13 +++++-------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/problems/amd_distributed/gemm-rs/submission.py b/problems/amd_distributed/gemm-rs/submission.py index dce77b4..4212d5a 100644 --- a/problems/amd_distributed/gemm-rs/submission.py +++ b/problems/amd_distributed/gemm-rs/submission.py @@ -7,24 +7,20 @@ def custom_kernel(data: input_t) -> output_t: Reference kernel for Gemm-ReduceScatter operation. Args: - data: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool, - bias: Optional[torch.Tensor]) + data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) - input: Local input tensor of shape [M, local_K]. 
- - weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. - - transposed_weight: Whether the weight is transposed. + - weight: Weight tensor of shape [N, local_K]. - bias: Optional bias tensor of shape [N] or None. Returns: Tuple containing: - output: Resulting tensor of shape [M // world_size, N]. """ - input, weight, transposed_weight, bias = data + input, weight, bias = data M, local_K = input.shape - if not transposed_weight: - weight = weight.T - N = weight.shape[1] + N = weight.shape[0] world_size = torch.distributed.get_world_size() # matmul - output = torch.matmul(input, weight) + output = torch.matmul(input, weight.T) if bias is not None: output = output + bias # reduce scatter diff --git a/problems/amd_distributed/gemm-rs/task.yml b/problems/amd_distributed/gemm-rs/task.yml index c8b6bb8..4d36fdc 100644 --- a/problems/amd_distributed/gemm-rs/task.yml +++ b/problems/amd_distributed/gemm-rs/task.yml @@ -10,10 +10,9 @@ files: lang: "py" description: | - Implement a Gemm-ReduceScatter kernel for efficient transformer models - on a single MI300X device. + Implement a Gemm-ReduceScatter kernel on a single MI300X node. - ReduceScatter-Gemm (RS-Gemm) is a technique that combines the ReduceScatter + Gemm-ReduceScatter is a technique that combines the ReduceScatter communication pattern with General Matrix Multiplication (GEMM) to optimize the performance of transformer models on GPUs. It is particularly useful for handling large models that exceed the memory capacity of a single GPU by @@ -28,13 +27,11 @@ description: | taking advantage of its specific hardware features for maximum performance. Input: - - `data`: Tuple of (input: torch.Tensor, weights: torch.Tensor, transposed_weight: bool, - bias: Optional, None or torch.Tensor, TP_GROUP: group object) + - `data`: Tuple of (input: torch.Tensor, weights: torch.Tensor, + bias: Optional, None or torch.Tensor) - input: Local input tensor of shape [M, local_K]. 
- - weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. - - transposed_weight: Whether the weight is transposed. + - weight: Weight tensor of shape [N, local_K]. - bias: bias tensor of shape [N] or None. - - TP_GROUP: Process group for tensor parallelism Output: - Tuple containing: From 38225f614dbf2efd9619c70d9055784c5469a0fe Mon Sep 17 00:00:00 2001 From: Xun Wang Date: Mon, 8 Sep 2025 05:34:36 -0500 Subject: [PATCH 5/9] add has_bias in test spec --- problems/amd_distributed/gemm-rs/reference.py | 9 +++++++-- problems/amd_distributed/gemm-rs/task.py | 1 + problems/amd_distributed/gemm-rs/task.yml | 14 +++++++------- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/problems/amd_distributed/gemm-rs/reference.py b/problems/amd_distributed/gemm-rs/reference.py index 5a12156..03e4e5d 100644 --- a/problems/amd_distributed/gemm-rs/reference.py +++ b/problems/amd_distributed/gemm-rs/reference.py @@ -3,7 +3,7 @@ import torch -def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, seed: int) -> input_t: +def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t: """ Generate random input and weights for the Gemm-ReduceScatter operation. 
@@ -25,7 +25,12 @@ def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, seed: int input = (torch.rand((m, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 weight = (torch.rand((n, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 - return (input, weight, None) + bias = None + if has_bias: + gen.manual_seed(seed) + bias = (torch.rand((n,), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 + + return (input, weight, bias) def ref_kernel(data: input_t) -> output_t: diff --git a/problems/amd_distributed/gemm-rs/task.py b/problems/amd_distributed/gemm-rs/task.py index 9be0c81..1245626 100644 --- a/problems/amd_distributed/gemm-rs/task.py +++ b/problems/amd_distributed/gemm-rs/task.py @@ -10,4 +10,5 @@ class TestSpec(TypedDict): m: int n: int k: int + has_bias: bool seed: int \ No newline at end of file diff --git a/problems/amd_distributed/gemm-rs/task.yml b/problems/amd_distributed/gemm-rs/task.yml index 4d36fdc..5988132 100644 --- a/problems/amd_distributed/gemm-rs/task.yml +++ b/problems/amd_distributed/gemm-rs/task.yml @@ -46,13 +46,13 @@ templates: ranking_by: "geom" tests: - - {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "seed": 42} - - {"world_size": 8, "m": 8192, "n": 4096, "k": 12288, "seed": 6635} - - {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "seed": 4422} - - {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "seed": 1536} + - {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "has_bias": True, "seed": 42} + - {"world_size": 8, "m": 8192, "n": 4096, "k": 12288, "has_bias": False, "seed": 6635} + - {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "has_bias": True, "seed": 4422} + - {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "has_bias": False, "seed": 1536} benchmarks: - - {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "seed": 7168} - - {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "seed": 1024} - - {"world_size": 8, "m": 8192, 
"n": 8192, "k": 30720, "seed": 2035} + - {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "has_bias": True, "seed": 7168} + - {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "has_bias": False, "seed": 1024} + - {"world_size": 8, "m": 8192, "n": 8192, "k": 30720, "has_bias": True, "seed": 2035} From 5afcc47169d44e3d49024622d353c969cd3d27a5 Mon Sep 17 00:00:00 2001 From: Xun Wang Date: Tue, 9 Sep 2025 21:31:17 -0500 Subject: [PATCH 6/9] update test and benchmark shapes --- problems/amd_distributed/gemm-rs/task.yml | 33 ++++++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/problems/amd_distributed/gemm-rs/task.yml b/problems/amd_distributed/gemm-rs/task.yml index 5988132..360ff1e 100644 --- a/problems/amd_distributed/gemm-rs/task.yml +++ b/problems/amd_distributed/gemm-rs/task.yml @@ -37,6 +37,21 @@ description: | - Tuple containing: - output: Resulting tensor of shape [M // world_size, N] + The ranking criterion is the geometric mean of the benchmark results. + + For the grand prize, your kernel will be evaluated against the speed of light + analysis and AMD implementations; the solution closest to the speed of light + and AMD implementations will be awarded the grand prize. 
+ ``` + The speed of light analysis is: + m n k has_bias time[us] + 64 7168 18432 False 6.46 + 512 4096 12288 True 8.19 + 2048 2880 2880 True 23.04 + 4096 4096 4096 False 65.54 + 8192 4096 14336 True 131.07 + 8192 8192 29568 False 379.43 + ``` config: main: "eval.py" @@ -46,13 +61,23 @@ templates: ranking_by: "geom" tests: - - {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "has_bias": True, "seed": 42} - - {"world_size": 8, "m": 8192, "n": 4096, "k": 12288, "has_bias": False, "seed": 6635} + - {"world_size": 8, "m": 64, "n": 2880, "k": 2880, "has_bias": True, "seed": 2035} + - {"world_size": 8, "m": 64, "n": 3584, "k": 14336, "has_bias": True, "seed": 13} + - {"world_size": 8, "m": 512, "n": 3584, "k": 14336, "has_bias": True, "seed": 4297} + - {"world_size": 8, "m": 512, "n": 4608, "k": 36864, "has_bias": False, "seed": 1597} + - {"world_size": 8, "m": 2048, "n": 4096, "k": 7168, "has_bias": False, "seed": 716} + - {"world_size": 8, "m": 2048, "n": 8192, "k": 30720, "has_bias": False, "seed": 20201} + - {"world_size": 8, "m": 4096, "n": 2880, "k": 2880, "has_bias": True, "seed": 136} + - {"world_size": 8, "m": 4096, "n": 8192, "k": 2048, "has_bias": True, "seed": 138} + - {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "has_bias": True, "seed": 748} - {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "has_bias": True, "seed": 4422} - {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "has_bias": False, "seed": 1536} benchmarks: + - {"world_size": 8, "m": 64, "n": 7168, "k": 18432, "has_bias": False, "seed": 1234} + - {"world_size": 8, "m": 512, "n": 4096, "k": 12288, "has_bias": True, "seed": 663} + - {"world_size": 8, "m": 2048, "n": 2880, "k": 2880, "has_bias": True, "seed": 166} + - {"world_size": 8, "m": 4096, "n": 4096, "k": 4096, "has_bias": False, "seed": 1371} - {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "has_bias": True, "seed": 7168} - - {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "has_bias": False, "seed": 1024} - - 
{"world_size": 8, "m": 8192, "n": 8192, "k": 30720, "has_bias": True, "seed": 2035} + - {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "has_bias": False, "seed": 42} From 10015ea46d41d0ffd55056905566f13645e8063f Mon Sep 17 00:00:00 2001 From: Xun Wang Date: Wed, 10 Sep 2025 21:09:37 -0500 Subject: [PATCH 7/9] change dtype --- problems/amd_distributed/gemm-rs/reference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/problems/amd_distributed/gemm-rs/reference.py b/problems/amd_distributed/gemm-rs/reference.py index 03e4e5d..6cb363d 100644 --- a/problems/amd_distributed/gemm-rs/reference.py +++ b/problems/amd_distributed/gemm-rs/reference.py @@ -22,13 +22,13 @@ def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, has_bias: local_k = k // world_size # Generate random inputs and weights - input = (torch.rand((m, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 - weight = (torch.rand((n, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 + input = (torch.rand((m, local_k), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 + weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 bias = None if has_bias: gen.manual_seed(seed) - bias = (torch.rand((n,), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 + bias = (torch.rand((n,), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 return (input, weight, bias) From 75bd4abc5f4acab1ee1ec38b8a22adb85b6b5d74 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Fri, 12 Sep 2025 14:24:43 +0200 Subject: [PATCH 8/9] Feat: works --- problems/amd_distributed.yaml | 5 ++++ problems/amd_distributed/gemm-rs/reference.py | 23 +++++++++++++------ problems/amd_distributed/gemm-rs/task.yml | 1 + 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/problems/amd_distributed.yaml b/problems/amd_distributed.yaml index 86fed95..fc9121f 100644 --- 
a/problems/amd_distributed.yaml +++ b/problems/amd_distributed.yaml @@ -10,3 +10,8 @@ problems: deadline: "2025-10-14" gpus: - MI300x8 + - directory: amd_distributed/gemm-rs + name: amd-gemm-rs + deadline: "2025-10-14" + gpus: + - MI300x8 diff --git a/problems/amd_distributed/gemm-rs/reference.py b/problems/amd_distributed/gemm-rs/reference.py index 6cb363d..dcffa08 100644 --- a/problems/amd_distributed/gemm-rs/reference.py +++ b/problems/amd_distributed/gemm-rs/reference.py @@ -3,7 +3,7 @@ import torch -def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t: +def generate_input(rank: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t: """ Generate random input and weights for the Gemm-ReduceScatter operation. @@ -14,21 +14,22 @@ def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, has_bias: bias: Optional[torch.Tensor], # [N] or None ) """ - gen = torch.Generator(device='cuda') - gen.manual_seed(seed + RANK) + device = torch.device(f'cuda:{rank}') + gen = torch.Generator(device=device) + gen.manual_seed(seed + rank) assert m % world_size == 0, "m must be divisible by world_size" assert k % world_size == 0, "k must be divisible by world_size" local_k = k // world_size # Generate random inputs and weights - input = (torch.rand((m, local_k), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 - weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 + input = (torch.rand((m, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 + weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 bias = None if has_bias: gen.manual_seed(seed) - bias = (torch.rand((n,), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 + bias = (torch.rand((n,), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 return (input, 
weight, bias) @@ -60,4 +61,12 @@ def ref_kernel(data: input_t) -> output_t: return rs_output -check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2) +def check_implementation(data: input_t, output: output_t): + expected = ref_kernel(data) + if output.device != expected.device: + return False, f"Output device mismatch: {output.device} != {expected.device}" + res = torch.allclose(output, expected, rtol=1e-2, atol=1e-2) + if not res: + return False, f"Output values mismatch, {output} != {expected}" + + return True, "" diff --git a/problems/amd_distributed/gemm-rs/task.yml b/problems/amd_distributed/gemm-rs/task.yml index 360ff1e..d050136 100644 --- a/problems/amd_distributed/gemm-rs/task.yml +++ b/problems/amd_distributed/gemm-rs/task.yml @@ -8,6 +8,7 @@ files: - {"name": "eval.py", "source": "../eval.py"} lang: "py" +multi_gpu: true description: | Implement a Gemm-ReduceScatter kernel on a single MI300X node. From 0d0ca8484c27a214d9b8f0c6f8b49253b0520da8 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sun, 14 Sep 2025 12:35:15 +0200 Subject: [PATCH 9/9] Final --- problems/amd_distributed/gemm-rs/reference.py | 1 - problems/amd_distributed/gemm-rs/task.py | 8 ++++---- problems/amd_distributed/gemm-rs/task.yml | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/problems/amd_distributed/gemm-rs/reference.py b/problems/amd_distributed/gemm-rs/reference.py index dcffa08..cb60206 100644 --- a/problems/amd_distributed/gemm-rs/reference.py +++ b/problems/amd_distributed/gemm-rs/reference.py @@ -1,4 +1,3 @@ -from utils import make_match_reference from task import input_t, output_t import torch diff --git a/problems/amd_distributed/gemm-rs/task.py b/problems/amd_distributed/gemm-rs/task.py index 1245626..1de3edd 100644 --- a/problems/amd_distributed/gemm-rs/task.py +++ b/problems/amd_distributed/gemm-rs/task.py @@ -1,8 +1,8 @@ -from typing import TypedDict, TypeVar, Tuple, Dict +from typing import TypedDict, TypeVar, Tuple, Optional import 
torch -input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict]) -output_t = TypeVar("output_t", bound=Tuple[torch.Tensor, Dict]) +input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]) +output_t = TypeVar("output_t", bound=torch.Tensor) class TestSpec(TypedDict): @@ -11,4 +11,4 @@ class TestSpec(TypedDict): n: int k: int has_bias: bool - seed: int \ No newline at end of file + seed: int diff --git a/problems/amd_distributed/gemm-rs/task.yml b/problems/amd_distributed/gemm-rs/task.yml index d050136..6eac274 100644 --- a/problems/amd_distributed/gemm-rs/task.yml +++ b/problems/amd_distributed/gemm-rs/task.yml @@ -60,6 +60,7 @@ templates: Python: "submission.py" ranking_by: "geom" +ranked_timeout: 360 # just in case tests: - {"world_size": 8, "m": 64, "n": 2880, "k": 2880, "has_bias": True, "seed": 2035}