@@ -35,16 +35,15 @@
 @dataclass(frozen=True)
 class ExperimentConfig:
     high_precision_dtype: torch.dtype
-    A_shape: tuple[int]
-    B_shape: tuple[int]
+    MNKG: tuple[int, int, int, int]
     recipe: MoEScalingType


 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_e2e_us: float
-    scaled_e2e_us: float
-    scaled_e2e_speedup: float
+    bf16_fwd_bwd_us: float
+    scaled_fwd_bwd_us: float
+    scaled_fwd_bwd_speedup: float
     bf16_fwd_us: float
     scaled_fwd_us: float
     scaled_fwd_speedup: float
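Note on the new MNKG field: it packs the grouped-GEMM problem size into a single tuple, unpacked in run_experiment below. A minimal sketch of its intended meaning (one entry from MNKG_list; the annotations are inferred from how run_experiment uses the tuple, not part of the diff):

total_M, N, K, G = (16384, 8192, 5120, 4)
# total_M: tokens stacked across all expert groups (rows of A)
# N, K:    per-expert weight dims; B_t is (G, N, K)
# G:       expert groups resident on this device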
@@ -57,22 +56,46 @@ class Experiment:


 def get_configs() -> List[ExperimentConfig]:
-    # Llama4 shapes
-    A_shapes = [(16640, 5120)]
-    B_shapes = [(1, 8192, 5120), (4, 8192, 5120), (16, 8192, 5120), (64, 8192, 5120)]
+    MNKG_list = [
+        # Llama4 16e with various experts per device (i.e., different EP degrees)
+        (16384, 8192, 5120, 1),
+        (16384, 8192, 5120, 2),
+        (16384, 8192, 5120, 4),
+        (16384, 8192, 5120, 8),
+        (128000, 8192, 5120, 1),
+        (128000, 8192, 5120, 2),
+        (128000, 8192, 5120, 4),
+        (128000, 8192, 5120, 8),
+        # DSV3 236B with various experts per device (i.e., different EP degrees)
+        (16384, 1536, 5120, 1),
+        (16384, 1536, 5120, 2),
+        (16384, 1536, 5120, 4),
+        (16384, 1536, 5120, 8),
+        (128000, 1536, 5120, 1),
+        (128000, 1536, 5120, 2),
+        (128000, 1536, 5120, 4),
+        (128000, 1536, 5120, 8),
+        # DSV3 671B with various experts per device (i.e., different EP degrees)
+        (16384, 2048, 7168, 1),
+        (16384, 2048, 7168, 2),
+        (16384, 2048, 7168, 4),
+        (16384, 2048, 7168, 8),
+        (128000, 2048, 7168, 1),
+        (128000, 2048, 7168, 2),
+        (128000, 2048, 7168, 4),
+        (128000, 2048, 7168, 8),
+    ]
     recipes = [MoEScalingType.FP8_ROWWISE, MoEScalingType.MXFP8]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
-    for A_shape, B_shape, recipe, high_precision_dtype in itertools.product(
-        A_shapes,
-        B_shapes,
+    for MNKG, recipe, high_precision_dtype in itertools.product(
+        MNKG_list,
         recipes,
         high_precision_dtypes,
     ):
         configs.append(
             ExperimentConfig(
-                A_shape=A_shape,
-                B_shape=B_shape,
+                MNKG=MNKG,
                 recipe=recipe,
                 high_precision_dtype=high_precision_dtype,
             )
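For scale: the sweep is a full cartesian product, so this hunk grows the run from 1 A_shape x 4 B_shapes x 2 recipes = 8 configs to 24 shape tuples x 2 recipes x 1 dtype = 48 experiments. A quick sanity check of that count, mirroring the loop above (hypothetical snippet, not part of the diff):

import itertools
n_shapes, n_recipes, n_dtypes = 24, 2, 1
assert len(list(itertools.product(range(n_shapes), range(n_recipes), range(n_dtypes)))) == 48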
@@ -83,15 +106,17 @@ def get_configs() -> List[ExperimentConfig]:
 def run_experiment(
     config: ExperimentConfig, args: argparse.Namespace
 ) -> ExperimentResult:
+    total_M, N, K, G = config.MNKG
+
     # define test inputs
     A = torch.randn(
-        *config.A_shape,
+        (total_M, K),
         dtype=config.high_precision_dtype,
         device=device,
         requires_grad=True,
     )
     B_t = torch.randn(
-        *config.B_shape,
+        (G, N, K),
         dtype=config.high_precision_dtype,
         device=device,
         requires_grad=True,
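The two operands above form the usual token-major grouped-GEMM pair: A stacks total_M tokens of width K, and B_t holds G per-expert (N, K) weight matrices. As a sketch of the computation being benchmarked, here is a pure-PyTorch reference that makes the shape contract concrete (illustrative only, with toy dimensions; the kernels actually measured are torch._grouped_mm and _scaled_grouped_mm):

import torch

def grouped_mm_reference(A, B_t, offs):
    # A: (total_M, K); B_t: (G, N, K); offs: (G,) cumulative group end rows.
    out = torch.empty(A.shape[0], B_t.shape[1], dtype=A.dtype)
    start = 0
    for g, end in enumerate(offs.tolist()):
        # tokens [start:end) route to expert g: (end - start, K) @ (K, N)
        out[start:end] = A[start:end] @ B_t[g].T
        start = end
    return out

A = torch.randn(256, 32, dtype=torch.bfloat16)      # total_M=256, K=32
B_t = torch.randn(4, 64, 32, dtype=torch.bfloat16)  # G=4, N=64, K=32
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32)
print(grouped_mm_reference(A, B_t, offs).shape)     # torch.Size([256, 64])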
@@ -102,17 +127,15 @@ def run_experiment(
     # that occurs in the backward pass of the differentiable scaled grouped mm.
     # - the transposed tensor in col-major format with groups along the row dimension,
     # which represents the right operand.
-    n_groups = config.B_shape[0]
-    Mg = A.shape[0]
     token_group_alignment_size = 32 if config.recipe == MoEScalingType.MXFP8 else 16
-    offs = generate_jagged_offs(n_groups, Mg, multiple_of=token_group_alignment_size)
+    offs = generate_jagged_offs(G, total_M, multiple_of=token_group_alignment_size)

     labels = torch.ones(
         (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
     )

-    # E2E bf16 benchmark + profiling
-    bf16_e2e_us = bench_fwd_bwd_microseconds(
+    # bf16 forward + backward benchmark + profiling
+    bf16_fwd_bwd_us = bench_fwd_bwd_microseconds(
         torch._grouped_mm,
         A,
         B_t,
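On the offs generated in the hunk above: generate_jagged_offs places every group boundary on a multiple of token_group_alignment_size, so each expert's token count is divisible by the granularity the scaling recipe needs (32 presumably matching MXFP8's 32-element scaling blocks, 16 for FP8 rowwise). A small illustration of that invariant with hypothetical values:

import torch

offs = torch.tensor([96, 224, 256], dtype=torch.int32)  # 3 groups over 256 tokens
lengths = torch.diff(offs, prepend=torch.tensor([0], dtype=torch.int32))
align = 32  # MXFP8 case
assert bool((offs % align == 0).all()) and bool((lengths % align == 0).all())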
@@ -133,8 +156,8 @@ def run_experiment(
         profile_name="bf16_profile",
     )

-    # E2E scaled benchmark + profiling
-    scaled_e2e_us = bench_fwd_bwd_microseconds(
+    # scaled forward + backward benchmark + profiling
+    scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
         _scaled_grouped_mm,
         A,
         B_t,
@@ -177,9 +200,9 @@ def run_experiment(
     )

     return ExperimentResult(
-        bf16_e2e_us=round(bf16_e2e_us, 3),
-        scaled_e2e_us=round(scaled_e2e_us, 3),
-        scaled_e2e_speedup=round(bf16_e2e_us / scaled_e2e_us, 3),
+        bf16_fwd_bwd_us=round(bf16_fwd_bwd_us, 3),
+        scaled_fwd_bwd_us=round(scaled_fwd_bwd_us, 3),
+        scaled_fwd_bwd_speedup=round(bf16_fwd_bwd_us / scaled_fwd_bwd_us, 3),
         bf16_fwd_us=round(bf16_fwd_us, 3),
         scaled_fwd_us=round(scaled_fwd_us, 3),
         scaled_fwd_speedup=round(bf16_fwd_us / scaled_fwd_us, 3),
@@ -188,28 +211,24 @@ def run_experiment(

 def print_results(experiments: List[Experiment]):
     headers = [
-        "A_shape",
-        "B_shape",
+        "M,N,K,G",
         "recipe",
-        "bf16_e2e_us",
-        "scaled_e2e_us",
-        "scaled_e2e_speedup",
+        "bf16_fwd_bwd_us",
+        "scaled_fwd_bwd_us",
+        "scaled_fwd_bwd_speedup",
         "bf16_fwd_us",
         "scaled_fwd_us",
         "scaled_fwd_speedup",
     ]
     rows = []
     for experiment in experiments:
-        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
-        B_shape = f"({experiment.config.B_shape[0]}, {experiment.config.B_shape[1]}, {experiment.config.B_shape[2]})"
         rows.append(
             [
-                A_shape,
-                B_shape,
+                str(experiment.config.MNKG),
                 experiment.config.recipe,
-                experiment.result.bf16_e2e_us,
-                experiment.result.scaled_e2e_us,
-                f"{experiment.result.scaled_e2e_speedup}x",
+                experiment.result.bf16_fwd_bwd_us,
+                experiment.result.scaled_fwd_bwd_us,
+                f"{experiment.result.scaled_fwd_bwd_speedup}x",
                 experiment.result.bf16_fwd_us,
                 experiment.result.scaled_fwd_us,
                 f"{experiment.result.scaled_fwd_speedup}x",