-
Notifications
You must be signed in to change notification settings - Fork 45
add reference impl of gemm-reducescatter #58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+204
−0
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
dfced0b
add reference impl of gemm-reducescatter
35eb4e1
typo fix
8e35bf1
remove transposed_weight bool variable
11303be
update custom_kernel and yaml doc
38225f6
add has_bias in test spec
5afcc47
update test and benchmark shapes
10015ea
change dtype
75bd4ab
Feat: works
S1ro1 0d0ca84
Final
S1ro1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from task import input_t, output_t | ||
import torch | ||
|
||
|
||
def generate_input(rank: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t: | ||
""" | ||
Generate random input and weights for the Gemm-ReduceScatter operation. | ||
|
||
Returns: | ||
Tuple of ( | ||
input: torch.Tensor, # [M, local_K] | ||
weight: torch.Tensor, # [N, local_K] | ||
bias: Optional[torch.Tensor], # [N] or None | ||
) | ||
""" | ||
device = torch.device(f'cuda:{rank}') | ||
gen = torch.Generator(device=device) | ||
gen.manual_seed(seed + rank) | ||
|
||
assert m % world_size == 0, "m must be divisible by world_size" | ||
assert k % world_size == 0, "k must be divisible by world_size" | ||
local_k = k // world_size | ||
|
||
# Generate random inputs and weights | ||
input = (torch.rand((m, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 | ||
weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 | ||
|
||
bias = None | ||
if has_bias: | ||
gen.manual_seed(seed) | ||
bias = (torch.rand((n,), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 | ||
|
||
return (input, weight, bias) | ||
|
||
|
||
def ref_kernel(data: input_t) -> output_t: | ||
""" | ||
Reference kernel for Gemm-ReduceScatter operation. | ||
|
||
Args: | ||
data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) | ||
- input: Local input tensor of shape [M, local_K]. | ||
- weight: Weight tensor of shape [N, local_K]. | ||
- bias: Optional bias tensor of shape [N] or None. | ||
Returns: | ||
Tuple containing: | ||
- output: Resulting tensor of shape [M // world_size, N]. | ||
""" | ||
input, weight, bias = data | ||
M, local_K = input.shape | ||
N = weight.shape[0] | ||
world_size = torch.distributed.get_world_size() | ||
# matmul | ||
output = torch.matmul(input, weight.T) | ||
if bias is not None: | ||
output = output + bias | ||
# reduce scatter | ||
rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device) | ||
torch.distributed.reduce_scatter_tensor(rs_output, output) | ||
return rs_output | ||
|
||
|
||
def check_implementation(data: input_t, output: output_t): | ||
expected = ref_kernel(data) | ||
if output.device != expected.device: | ||
return False, f"Output device mismatch: {output.device} != {expected.device}" | ||
res = torch.allclose(output, expected, rtol=1e-2, atol=1e-2) | ||
if not res: | ||
return False, f"Output values mismatch, {output} != {expected}" | ||
|
||
return True, "" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from task import input_t, output_t | ||
import torch | ||
|
||
|
||
def custom_kernel(data: input_t) -> output_t: | ||
""" | ||
Reference kernel for Gemm-ReduceScatter operation. | ||
|
||
Args: | ||
data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) | ||
- input: Local input tensor of shape [M, local_K]. | ||
- weight: Weight tensor of shape [N, local_K]. | ||
- bias: Optional bias tensor of shape [N] or None. | ||
Returns: | ||
Tuple containing: | ||
- output: Resulting tensor of shape [M // world_size, N]. | ||
""" | ||
input, weight, bias = data | ||
M, local_K = input.shape | ||
N = weight.shape[0] | ||
world_size = torch.distributed.get_world_size() | ||
# matmul | ||
output = torch.matmul(input, weight.T) | ||
if bias is not None: | ||
output = output + bias | ||
# reduce scatter | ||
rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device) | ||
torch.distributed.reduce_scatter_tensor(rs_output, output) | ||
return rs_output |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import TypedDict, TypeVar, Tuple, Optional | ||
import torch | ||
|
||
input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]) | ||
output_t = TypeVar("output_t", bound=torch.Tensor) | ||
|
||
|
||
class TestSpec(TypedDict): | ||
world_size: int | ||
m: int | ||
n: int | ||
k: int | ||
has_bias: bool | ||
seed: int |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# name: gemm-rs | ||
|
||
files: | ||
- {"name": "submission.py", "source": "@SUBMISSION@"} | ||
- {"name": "task.py", "source": "task.py"} | ||
- {"name": "utils.py", "source": "../utils.py"} | ||
- {"name": "reference.py", "source": "reference.py"} | ||
- {"name": "eval.py", "source": "../eval.py"} | ||
|
||
lang: "py" | ||
multi_gpu: true | ||
|
||
description: | | ||
Implement a Gemm-ReduceScatter kernel on a single MI300X node. | ||
|
||
Gemm-ReduceScatter is a technique that combines the ReduceScatter | ||
communication pattern with General Matrix Multiplication (GEMM) to optimize | ||
the performance of transformer models on GPUs. It is particularly useful for | ||
handling large models that exceed the memory capacity of a single GPU by | ||
distributing the model across multiple GPUs and efficiently scattering the | ||
results of matrix multiplications. | ||
|
||
Your task: | ||
- Implement the Gemm-RS kernel to perform matrix multiplications in a | ||
distributed manner, leveraging the ReduceScatter operation to distribute | ||
data across multiple GPUs. | ||
- Ensure that the implementation is optimized for the MI300X architecture, | ||
taking advantage of its specific hardware features for maximum performance. | ||
|
||
Input: | ||
- `data`: Tuple of (input: torch.Tensor, weights: torch.Tensor, | ||
bias: Optional, None or torch.Tensor) | ||
- input: Local input tensor of shape [M, local_K]. | ||
- weight: Weight tensor of shape [N, local_K]. | ||
- bias: bias tensor of shape [N] or None. | ||
|
||
Output: | ||
- Tuple containing: | ||
- output: Resulting tensor of shape [M // world_size, N] | ||
|
||
The ranking criteria is the geometric mean of the benchmark results. | ||
|
||
For the grand price, your kernel will be evaluated against the speed of light | ||
analysis and AMD implementations, the solution closest to the speed of light | ||
and AMD implementations will be awarded the grand price. | ||
``` | ||
The speed of light analysis is: | ||
m n k has_bias time[us] | ||
64 7168 18432 False 6.46 | ||
512 4096 12288 True 8.19 | ||
2048 2880 2880 True 23.04 | ||
4096 4096 4096 False 65.54 | ||
8192 4096 14336 True 131.07 | ||
8192 8192 29568 False 379.43 | ||
``` | ||
config: | ||
main: "eval.py" | ||
|
||
templates: | ||
Python: "submission.py" | ||
|
||
ranking_by: "geom" | ||
ranked_timeout: 360 # just in case | ||
|
||
tests: | ||
- {"world_size": 8, "m": 64, "n": 2880, "k": 2880, "has_bias": True, "seed": 2035} | ||
- {"world_size": 8, "m": 64, "n": 3584, "k": 14336, "has_bias": True, "seed": 13} | ||
- {"world_size": 8, "m": 512, "n": 3584, "k": 14336, "has_bias": True, "seed": 4297} | ||
- {"world_size": 8, "m": 512, "n": 4608, "k": 36864, "has_bias": False, "seed": 1597} | ||
- {"world_size": 8, "m": 2048, "n": 4096, "k": 7168, "has_bias": False, "seed": 716} | ||
- {"world_size": 8, "m": 2048, "n": 8192, "k": 30720, "has_bias": False, "seed": 20201} | ||
- {"world_size": 8, "m": 4096, "n": 2880, "k": 2880, "has_bias": True, "seed": 136} | ||
- {"world_size": 8, "m": 4096, "n": 8192, "k": 2048, "has_bias": True, "seed": 138} | ||
- {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "has_bias": True, "seed": 748} | ||
- {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "has_bias": True, "seed": 4422} | ||
- {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "has_bias": False, "seed": 1536} | ||
|
||
|
||
benchmarks: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add various different shapes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shapes definition updated. |
||
- {"world_size": 8, "m": 64, "n": 7168, "k": 18432, "has_bias": False, "seed": 1234} | ||
- {"world_size": 8, "m": 512, "n": 4096, "k": 12288, "has_bias": True, "seed": 663} | ||
- {"world_size": 8, "m": 2048, "n": 2880, "k": 2880, "has_bias": True, "seed": 166} | ||
- {"world_size": 8, "m": 4096, "n": 4096, "k": 4096, "has_bias": False, "seed": 1371} | ||
- {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "has_bias": True, "seed": 7168} | ||
- {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "has_bias": False, "seed": 42} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.