# name: gemm-rs

files:
  - {"name": "submission.py", "source": "@SUBMISSION@"}
  - {"name": "task.py", "source": "task.py"}
  - {"name": "utils.py", "source": "../utils.py"}
  - {"name": "reference.py", "source": "reference.py"}
  - {"name": "eval.py", "source": "../eval.py"}

lang: "py"
multi_gpu: true

description: |
  Implement a Gemm-ReduceScatter kernel on a single MI300X node.

  Gemm-ReduceScatter is a technique that combines General Matrix
  Multiplication (GEMM) with the ReduceScatter communication pattern to
  optimize the performance of transformer models on GPUs. It is particularly
  useful for large models that exceed the memory capacity of a single GPU:
  the model is sharded across multiple GPUs, and the partial GEMM results
  are reduced and scattered so that each rank ends up with its shard of the
  output.

  Your task:
  - Implement the Gemm-RS kernel so that the matrix multiplication is
    performed in a distributed manner, using the ReduceScatter operation to
    reduce the partial results across GPUs and leave each rank with its
    shard of the output (a minimal semantics sketch is given after the
    Input/Output spec below).
  - Ensure that the implementation is optimized for the MI300X architecture,
    taking advantage of its specific hardware features for maximum
    performance.

  Input:
  - `data`: Tuple of (input: torch.Tensor, weight: torch.Tensor,
    bias: Optional[torch.Tensor])
    - input: local input tensor of shape [M, local_K].
    - weight: weight tensor of shape [N, local_K].
    - bias: bias tensor of shape [N], or None.

  Output:
  - Tuple containing:
    - output: resulting tensor of shape [M // world_size, N]

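  The following is a minimal, unoptimized sketch of the intended semantics,
  assuming the standard torch.distributed API and a process group already
  initialized by the harness. The function name `gemm_rs_sketch` is
  illustrative only; the authoritative interface, return convention, and
  numerics (including where the bias is applied) are defined in task.py and
  reference.py.
  ```python
  import torch
  import torch.distributed as dist

  def gemm_rs_sketch(data):
      # data = (input, weight, bias) with shapes [M, local_K], [N, local_K], [N] or None
      input, weight, bias = data
      # Local partial product over this rank's K-shard: [M, N]
      partial = torch.matmul(input, weight.T)
      # Reduce the partial products across ranks and scatter row-shards,
      # leaving each rank with a [M // world_size, N] slice of the result.
      world_size = dist.get_world_size()
      output = torch.empty(partial.shape[0] // world_size, partial.shape[1],
                           dtype=partial.dtype, device=partial.device)
      dist.reduce_scatter_tensor(output, partial, op=dist.ReduceOp.SUM)
      if bias is not None:
          # Assumption: bias is added once, after the cross-rank reduction.
          output = output + bias
      return output
  ```
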
  The ranking criterion is the geometric mean of the benchmark results.

  For the grand prize, your kernel will be evaluated against the
  speed-of-light analysis and the AMD reference implementations; the
  solution closest to the speed of light and the AMD implementations will be
  awarded the grand prize.
  The speed of light analysis is:
  ```
  m       n       k       has_bias   time[us]
  64      7168    18432   False      6.46
  512     4096    12288   True       8.19
  2048    2880    2880    True       23.04
  4096    4096    4096    False      65.54
  8192    4096    14336   True       131.07
  8192    8192    29568   False      379.43
  ```

config:
  main: "eval.py"

templates:
  Python: "submission.py"

ranking_by: "geom"
ranked_timeout: 360  # just in case

tests:
  - {"world_size": 8, "m": 64, "n": 2880, "k": 2880, "has_bias": True, "seed": 2035}
  - {"world_size": 8, "m": 64, "n": 3584, "k": 14336, "has_bias": True, "seed": 13}
  - {"world_size": 8, "m": 512, "n": 3584, "k": 14336, "has_bias": True, "seed": 4297}
  - {"world_size": 8, "m": 512, "n": 4608, "k": 36864, "has_bias": False, "seed": 1597}
  - {"world_size": 8, "m": 2048, "n": 4096, "k": 7168, "has_bias": False, "seed": 716}
  - {"world_size": 8, "m": 2048, "n": 8192, "k": 30720, "has_bias": False, "seed": 20201}
  - {"world_size": 8, "m": 4096, "n": 2880, "k": 2880, "has_bias": True, "seed": 136}
  - {"world_size": 8, "m": 4096, "n": 8192, "k": 2048, "has_bias": True, "seed": 138}
  - {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "has_bias": True, "seed": 748}
  - {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "has_bias": True, "seed": 4422}
  - {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "has_bias": False, "seed": 1536}

benchmarks:
  - {"world_size": 8, "m": 64, "n": 7168, "k": 18432, "has_bias": False, "seed": 1234}
  - {"world_size": 8, "m": 512, "n": 4096, "k": 12288, "has_bias": True, "seed": 663}
  - {"world_size": 8, "m": 2048, "n": 2880, "k": 2880, "has_bias": True, "seed": 166}
  - {"world_size": 8, "m": 4096, "n": 4096, "k": 4096, "has_bias": False, "seed": 1371}
  - {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "has_bias": True, "seed": 7168}
  - {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "has_bias": False, "seed": 42}