3
3
import torch
4
4
5
5
6
def generate_input(rank: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t:
    """Generate random input and weights for the Gemm-ReduceScatter operation.

    Each rank draws its own shard of the activations/weights from a
    rank-seeded generator so shards differ across ranks, while the bias is
    drawn from the bare seed so every rank sees the identical bias.

    Returns:
        (
            input:  torch.Tensor,           # [m, k // world_size], bfloat16
            weight: torch.Tensor,           # [n, k // world_size], bfloat16
            bias:   Optional[torch.Tensor], # [n] or None
        )
    """
    device = torch.device(f'cuda:{rank}')
    gen = torch.Generator(device=device)
    gen.manual_seed(seed + rank)

    assert m % world_size == 0, "m must be divisible by world_size"
    assert k % world_size == 0, "k must be divisible by world_size"
    local_k = k // world_size

    # Uniform in [-0.01, 0.01): rand() gives [0, 1), scale to [-1, 1), then shrink.
    def _rand_bf16(shape):
        return (torch.rand(shape, dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01

    input = _rand_bf16((m, local_k))
    weight = _rand_bf16((n, local_k))

    bias = None
    if has_bias:
        # Re-seed with the plain seed so the bias is identical on all ranks.
        gen.manual_seed(seed)
        bias = _rand_bf16((n,))

    return (input, weight, bias)
35
@@ -60,4 +61,12 @@ def ref_kernel(data: input_t) -> output_t:
60
61
return rs_output
61
62
62
63
63
- check_implementation = make_match_reference (ref_kernel , rtol = 1e-2 , atol = 1e-2 )
64
def check_implementation(data: input_t, output: output_t):
    """Check a submitted Gemm-ReduceScatter output against the reference kernel.

    Args:
        data: Tuple of (input, weight, bias) as produced by generate_input.
        output: The candidate kernel's output tensor.

    Returns:
        (True, "") when the output matches the reference within tolerance,
        otherwise (False, <human-readable reason>).
    """
    expected = ref_kernel(data)
    # Compare devices first: torch.allclose raises on cross-device tensors.
    if output.device != expected.device:
        return False, f"Output device mismatch: {output.device} != {expected.device}"
    # Compare shapes explicitly: torch.allclose raises a RuntimeError (rather
    # than returning False) on mismatched shapes, which would break the
    # (bool, reason) contract of this checker.
    if output.shape != expected.shape:
        return False, f"Output shape mismatch: {output.shape} != {expected.shape}"
    res = torch.allclose(output, expected, rtol=1e-2, atol=1e-2)
    if not res:
        return False, f"Output values mismatch, {output} != {expected}"

    return True, ""
0 commit comments