Skip to content

Commit 644d635

Browse files
[moe training] add benchmarks for dsv3 236b, 671b shapes; reorganize benchmarks dir
stack-info: PR: #2999, branch: danielvegamyhre/stack/68
1 parent ff3ba31 commit 644d635

7 files changed

+59
-41
lines changed
File renamed without changes.

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 56 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -35,16 +35,15 @@
3535
@dataclass(frozen=True)
3636
class ExperimentConfig:
3737
high_precision_dtype: torch.dtype
38-
A_shape: tuple[int]
39-
B_shape: tuple[int]
38+
MNKG: tuple[int]
4039
recipe: MoEScalingType
4140

4241

4342
@dataclass(frozen=True)
4443
class ExperimentResult:
45-
bf16_e2e_us: float
46-
scaled_e2e_us: float
47-
scaled_e2e_speedup: float
44+
bf16_fwd_bwd_us: float
45+
scaled_fwd_bwd_us: float
46+
scaled_fwd_bwd_speedup: float
4847
bf16_fwd_us: float
4948
scaled_fwd_us: float
5049
scaled_fwd_speedup: float
@@ -57,22 +56,46 @@ class Experiment:
5756

5857

5958
def get_configs() -> List[ExperimentConfig]:
60-
# Llama4 shapes
61-
A_shapes = [(16640, 5120)]
62-
B_shapes = [(1, 8192, 5120), (4, 8192, 5120), (16, 8192, 5120), (64, 8192, 5120)]
59+
MNKG_list = [
60+
# Llama4 16e with various experts per device (i.e., different EP degrees)
61+
(16384, 8192, 5120, 1),
62+
(16384, 8192, 5120, 2),
63+
(16384, 8192, 5120, 4),
64+
(16384, 8192, 5120, 8),
65+
(128000, 8192, 5120, 1),
66+
(128000, 8192, 5120, 2),
67+
(128000, 8192, 5120, 4),
68+
(128000, 8192, 5120, 8),
69+
# DSV3 236B with various experts per device (i.e., different EP degrees)
70+
(16384, 1536, 5120, 1),
71+
(16384, 1536, 5120, 2),
72+
(16384, 1536, 5120, 4),
73+
(16384, 1536, 5120, 8),
74+
(128000, 1536, 5120, 1),
75+
(128000, 1536, 5120, 2),
76+
(128000, 1536, 5120, 4),
77+
(128000, 1536, 5120, 8),
78+
# DSV3 671B with various experts per device (i.e., different EP degrees)
79+
(16384, 2048, 7168, 1),
80+
(16384, 2048, 7168, 2),
81+
(16384, 2048, 7168, 4),
82+
(16384, 2048, 7168, 8),
83+
(128000, 2048, 7168, 1),
84+
(128000, 2048, 7168, 2),
85+
(128000, 2048, 7168, 4),
86+
(128000, 2048, 7168, 8),
87+
]
6388
recipes = [MoEScalingType.FP8_ROWWISE, MoEScalingType.MXFP8]
6489
high_precision_dtypes = [torch.bfloat16]
6590
configs = []
66-
for A_shape, B_shape, recipe, high_precision_dtype in itertools.product(
67-
A_shapes,
68-
B_shapes,
91+
for MNKG, recipe, high_precision_dtype in itertools.product(
92+
MNKG_list,
6993
recipes,
7094
high_precision_dtypes,
7195
):
7296
configs.append(
7397
ExperimentConfig(
74-
A_shape=A_shape,
75-
B_shape=B_shape,
98+
MNKG=MNKG,
7699
recipe=recipe,
77100
high_precision_dtype=high_precision_dtype,
78101
)
@@ -83,15 +106,17 @@ def get_configs() -> List[ExperimentConfig]:
83106
def run_experiment(
84107
config: ExperimentConfig, args: argparse.Namespace
85108
) -> ExperimentResult:
109+
total_M, N, K, G = config.MNKG
110+
86111
# define test inputs
87112
A = torch.randn(
88-
*config.A_shape,
113+
(total_M, K),
89114
dtype=config.high_precision_dtype,
90115
device=device,
91116
requires_grad=True,
92117
)
93118
B_t = torch.randn(
94-
*config.B_shape,
119+
(G, N, K),
95120
dtype=config.high_precision_dtype,
96121
device=device,
97122
requires_grad=True,
@@ -102,17 +127,15 @@ def run_experiment(
102127
# that occurs in the backward pass of the differentiable scaled grouped mm.
103128
# - the transposed tensor in col-major format with groups along the row dimension,
104129
# which represents the right operand.
105-
n_groups = config.B_shape[0]
106-
Mg = A.shape[0]
107130
token_group_alignment_size = 32 if config.recipe == MoEScalingType.MXFP8 else 16
108-
offs = generate_jagged_offs(n_groups, Mg, multiple_of=token_group_alignment_size)
131+
offs = generate_jagged_offs(G, total_M, multiple_of=token_group_alignment_size)
109132

110133
labels = torch.ones(
111134
(A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
112135
)
113136

114-
# E2E bf16 benchmark + profiling
115-
bf16_e2e_us = bench_fwd_bwd_microseconds(
137+
# fwd_bwd bf16 benchmark + profiling
138+
bf16_fwd_bwd_us = bench_fwd_bwd_microseconds(
116139
torch._grouped_mm,
117140
A,
118141
B_t,
@@ -133,8 +156,8 @@ def run_experiment(
133156
profile_name="bf16_profile",
134157
)
135158

136-
# E2E scaled benchmark + profiling
137-
scaled_e2e_us = bench_fwd_bwd_microseconds(
159+
# fwd_bwd scaled benchmark + profiling
160+
scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
138161
_scaled_grouped_mm,
139162
A,
140163
B_t,
@@ -177,9 +200,9 @@ def run_experiment(
177200
)
178201

179202
return ExperimentResult(
180-
bf16_e2e_us=round(bf16_e2e_us, 3),
181-
scaled_e2e_us=round(scaled_e2e_us, 3),
182-
scaled_e2e_speedup=round(bf16_e2e_us / scaled_e2e_us, 3),
203+
bf16_fwd_bwd_us=round(bf16_fwd_bwd_us, 3),
204+
scaled_fwd_bwd_us=round(scaled_fwd_bwd_us, 3),
205+
scaled_fwd_bwd_speedup=round(bf16_fwd_bwd_us / scaled_fwd_bwd_us, 3),
183206
bf16_fwd_us=round(bf16_fwd_us, 3),
184207
scaled_fwd_us=round(scaled_fwd_us, 3),
185208
scaled_fwd_speedup=round(bf16_fwd_us / scaled_fwd_us, 3),
@@ -188,28 +211,24 @@ def run_experiment(
188211

189212
def print_results(experiments: List[Experiment]):
190213
headers = [
191-
"A_shape",
192-
"B_shape",
214+
"M,N,K,G",
193215
"recipe",
194-
"bf16_e2e_us",
195-
"scaled_e2e_us",
196-
"scaled_e2e_speedup",
216+
"bf16_fwd_bwd_us",
217+
"scaled_fwd_bwd_us",
218+
"scaled_fwd_bwd_speedup",
197219
"bf16_fwd_us",
198220
"scaled_fwd_us",
199221
"scaled_fwd_speedup",
200222
]
201223
rows = []
202224
for experiment in experiments:
203-
A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
204-
B_shape = f"({experiment.config.B_shape[0]}, {experiment.config.B_shape[1]}, {experiment.config.B_shape[2]})"
205225
rows.append(
206226
[
207-
A_shape,
208-
B_shape,
227+
str(experiment.config.MNKG),
209228
experiment.config.recipe,
210-
experiment.result.bf16_e2e_us,
211-
experiment.result.scaled_e2e_us,
212-
f"{experiment.result.scaled_e2e_speedup}x",
229+
experiment.result.bf16_fwd_bwd_us,
230+
experiment.result.scaled_fwd_bwd_us,
231+
f"{experiment.result.scaled_fwd_bwd_speedup}x",
213232
experiment.result.bf16_fwd_us,
214233
experiment.result.scaled_fwd_us,
215234
f"{experiment.result.scaled_fwd_speedup}x",
File renamed without changes.
File renamed without changes.
File renamed without changes.

benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py renamed to benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,9 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
8080

8181
Mg, K = input_shape
8282
input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)
83+
_, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
84+
input_group_offsets
85+
)
8386

8487
# bench torch
8588
compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups)
@@ -90,14 +93,10 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
9093
compiled_run_torch,
9194
input_tensor,
9295
input_group_offsets,
93-
Mg,
9496
K,
9597
)
9698

9799
# bench triton
98-
_, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
99-
input_group_offsets
100-
)
101100
triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
102101
input_tensor,
103102
input_group_offsets,
File renamed without changes.

0 commit comments

Comments
 (0)