-
Notifications
You must be signed in to change notification settings - Fork 56
add reference impl of gemm-reducescatter #58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
dfced0b
add reference impl of gemm-reducescatter
35eb4e1
typo fix
8e35bf1
remove transposed_weight bool variable
11303be
update custom_kernel and yaml doc
38225f6
add has_bias in test spec
5afcc47
update test and benchmark shapes
10015ea
change dtype
75bd4ab
Feat: works
S1ro1 0d0ca84
Final
S1ro1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from utils import make_match_reference | ||
from task import input_t, output_t | ||
import torch | ||
|
||
|
||
def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, seed: int) -> input_t: | ||
""" | ||
Generate random input and weights for the AG-GEMM operation. | ||
|
||
Returns: | ||
Tuple of ( | ||
input: torch.Tensor, # [M, local_K] | ||
weight: torch.Tensor, # [N, local_K] | ||
transposed_weight: bool, # Whether the weight is transposed | ||
wangxunx marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
bias: Optional[torch.Tensor], # [N] or None | ||
wangxunx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
""" | ||
gen = torch.Generator(device='cuda') | ||
gen.manual_seed(seed + RANK) | ||
|
||
assert m % world_size == 0, "m must be divisible by world_size" | ||
assert k % world_size == 0, "k must be divisible by world_size" | ||
local_k = k // world_size | ||
|
||
# Generate random inputs and weights | ||
input = (torch.rand((m, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 | ||
wangxunx marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
weight = (torch.rand((n, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01 | ||
|
||
return (input, weight, False, None) | ||
|
||
|
||
def ref_kernel(data: input_t) -> output_t: | ||
""" | ||
Reference kernel for Gemm-ReduceScatter operation. | ||
|
||
Args: | ||
data: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool, | ||
bias: Optional[torch.Tensor]) | ||
- input: Local input tensor of shape [M, local_K]. | ||
- weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. | ||
- transposed_weight: Whether the weight is transposed. | ||
- bias: Optional bias tensor of shape [N] or None. | ||
Returns: | ||
Tuple containing: | ||
- output: Resulting tensor of shape [M // world_size, N]. | ||
""" | ||
input, weight, transposed_weight, bias = data | ||
M, local_K = input.shape | ||
if not transposed_weight: | ||
weight = weight.T | ||
N = weight.shape[1] | ||
world_size = torch.distributed.get_world_size() | ||
# matmul | ||
output = torch.matmul(input, weight) | ||
if bias is not None: | ||
output = output + bias | ||
# reduce scatter | ||
rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device) | ||
torch.distributed.reduce_scatter_tensor(rs_output, output) | ||
return rs_output | ||
|
||
|
||
check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from task import input_t, output_t | ||
import torch | ||
|
||
|
||
def custom_kernel(data: input_t) -> output_t: | ||
""" | ||
Reference kernel for Gemm-ReduceScatter operation. | ||
|
||
Args: | ||
data: Tuple of (input: torch.Tensor, weight: torch.Tensor, transposed_weight: bool, | ||
bias: Optional[torch.Tensor]) | ||
- input: Local input tensor of shape [M, local_K]. | ||
- weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. | ||
- transposed_weight: Whether the weight is transposed. | ||
- bias: Optional bias tensor of shape [N] or None. | ||
Returns: | ||
Tuple containing: | ||
- output: Resulting tensor of shape [M // world_size, N]. | ||
""" | ||
input, weight, transposed_weight, bias = data | ||
M, local_K = input.shape | ||
if not transposed_weight: | ||
weight = weight.T | ||
N = weight.shape[1] | ||
world_size = torch.distributed.get_world_size() | ||
# matmul | ||
output = torch.matmul(input, weight) | ||
if bias is not None: | ||
output = output + bias | ||
# reduce scatter | ||
rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device) | ||
torch.distributed.reduce_scatter_tensor(rs_output, output) | ||
return rs_output |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from typing import TypedDict, TypeVar, Tuple, Dict | ||
import torch | ||
|
||
input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict]) | ||
output_t = TypeVar("output_t", bound=Tuple[torch.Tensor, Dict]) | ||
|
||
|
||
class TestSpec(TypedDict): | ||
world_size: int | ||
m: int | ||
n: int | ||
k: int | ||
seed: int |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# name: gemm-rs | ||
|
||
files: | ||
- {"name": "submission.py", "source": "@SUBMISSION@"} | ||
- {"name": "task.py", "source": "task.py"} | ||
- {"name": "utils.py", "source": "../utils.py"} | ||
- {"name": "reference.py", "source": "reference.py"} | ||
- {"name": "eval.py", "source": "../eval.py"} | ||
|
||
lang: "py" | ||
|
||
description: | | ||
Implement a Gemm-ReduceScatter kernel for efficient transformer models | ||
on a single MI300X device. | ||
|
||
ReduceScatter-Gemm (RS-Gemm) is a technique that combines the ReduceScatter | ||
communication pattern with General Matrix Multiplication (GEMM) to optimize | ||
the performance of transformer models on GPUs. It is particularly useful for | ||
handling large models that exceed the memory capacity of a single GPU by | ||
distributing the model across multiple GPUs and efficiently scattering the | ||
results of matrix multiplications. | ||
|
||
Your task: | ||
- Implement the Gemm-RS kernel to perform matrix multiplications in a | ||
distributed manner, leveraging the ReduceScatter operation to distribute | ||
data across multiple GPUs. | ||
- Ensure that the implementation is optimized for the MI300X architecture, | ||
taking advantage of its specific hardware features for maximum performance. | ||
|
||
Input: | ||
- `data`: Tuple of (input: torch.Tensor, weights: torch.Tensor, transposed_weight: bool, | ||
bias: Optional, None or torch.Tensor, TP_GROUP: group object) | ||
- input: Local input tensor of shape [M, local_K]. | ||
- weight: Weight tensor of shape [N, local_K] or [local_K, N] if transed_weight is True. | ||
- transposed_weight: Whether the weight is transposed. | ||
- bias: bias tensor of shape [N] or None. | ||
- TP_GROUP: Process group for tensor parallelism | ||
|
||
Output: | ||
- Tuple containing: | ||
- output: Resulting tensor of shape [M // world_size, N] | ||
|
||
config: | ||
main: "eval.py" | ||
|
||
templates: | ||
Python: "submission.py" | ||
|
||
ranking_by: "geom" | ||
|
||
tests: | ||
- {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "seed": 42} | ||
- {"world_size": 8, "m": 8192, "n": 4096, "k": 12288, "seed": 6635} | ||
- {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "seed": 4422} | ||
- {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "seed": 1536} | ||
|
||
|
||
benchmarks: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add various different shapes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shapes definition updated. |
||
- {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "seed": 7168} | ||
- {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "seed": 1024} | ||
- {"world_size": 8, "m": 8192, "n": 8192, "k": 30720, "seed": 2035} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.