
Commit 5d874ed

[mxfp8 moe training] use dim1 cast cuda kernel for 3d weights by reshaping to 2d
stack-info: PR: #2998, branch: danielvegamyhre/stack/67
1 parent 01bfca9 commit 5d874ed
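
The reshaping trick in the commit title relies on each expert's rows forming whole scale blocks: flattening the (E, N, K) weight to (E*N, K) lets the existing 2D dim1 cast CUDA kernel scale along the row dimension without any block ever straddling an expert boundary, as long as N is a multiple of the block size. A minimal CPU-only sketch of that invariant (toy shapes, plain torch reductions, not the CUDA kernel):

import torch

# Toy shapes; the real kernel runs on bf16 expert weights on GPU.
E, N, K, block_size = 2, 64, 16, 32
B = torch.randn(E, N, K)

# Per-block amax along the N dim, computed on the flattened 2D view...
amax_2d = (
    B.reshape(E * N, K).reshape(E * N // block_size, block_size, K).abs().amax(dim=1)
)

# ...equals the same reduction done independently per expert, because
# N % block_size == 0 means no scale block crosses an expert boundary.
amax_3d = B.reshape(E, N // block_size, block_size, K).abs().amax(dim=2).reshape(-1, K)

assert torch.equal(amax_2d, amax_3d)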

File tree

2 files changed (+60, -11 lines)

benchmarks/utils.py

Lines changed: 6 additions & 1 deletion

@@ -25,8 +25,13 @@ def fwd_bwd(*args, **kwargs):

 def bench_fwd_microseconds(fn, *args, use_compile=False, fullgraph=True, **kwargs):
     fn_compiled = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
+
+    def inference_fn(*args, **kwargs):
+        with torch.inference_mode():
+            return fn_compiled(*args, **kwargs)
+
     return benchmark_cuda_function_in_microseconds(
-        fn_compiled,
+        inference_fn,
         *args,
         **kwargs,
     )
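
For context, a hypothetical call to the updated helper; linear and x are made-up example inputs, and the import assumes the repo root is on PYTHONPATH. The point of the change is that the timed forward now runs under torch.inference_mode(), so no autograd graph is recorded during the measurement:

import torch

from benchmarks.utils import bench_fwd_microseconds  # assumes repo root on PYTHONPATH

linear = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)

# Forward-only timing of the (optionally compiled) callable.
fwd_us = bench_fwd_microseconds(linear, x, use_compile=True)
print(f"forward: {fwd_us:.1f} us")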

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 54 additions & 10 deletions

@@ -58,15 +58,13 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        logger.debug("Using fp8 rowwise for _scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
             offs,
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.debug("Using mxfp8 for _scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,

@@ -358,13 +356,17 @@ def backward(ctx, grad_out: torch.Tensor):

         # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_data = to_mx(
+        B_scales_ref, B_data_ref = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

+        # Experiment with cuda kernel
+        B = B_t.transpose(-2, -1)
+        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,

@@ -376,21 +378,26 @@ def backward(ctx, grad_out: torch.Tensor):
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         grad_A = torch._scaled_grouped_mm(
             grad_out_data,
-            B_data.transpose(-2, -1),
+            B_data,
             grad_out_scales_blocked,
             B_scales_blocked,
             offs=offs,
             out_dtype=out_dtype,
         )

-        # grad_out_t_data shape: (N, M)
+        # grad_out_t_data shape: (M, N)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_data = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            grad_out.transpose(-2, -1).contiguous(),
+        grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            grad_out,
+            block_size,
             elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            hp_dtype=grad_out.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
         )
+        grad_out_t_data = grad_out_t_mx.qdata
+        grad_out_t_scales = grad_out_t_mx._scale_e8m0

         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A_t_data shape: (K, M)
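
Aside: a small pure-torch sketch of the shape contract the updated comments above describe, with toy sizes and no torchao kernels. For a (M, N) grad_out scaled along M in blocks of block_size, the quantized data keeps the (M, N) layout while the per-block scales (the real cast derives e8m0 scales from a reduction of this shape) land as (N, M//block_size):

import torch

M, N, block_size = 128, 64, 32
grad_out = torch.randn(M, N)

# Column-wise block amax along M: one value per (block_size, 1) tile.
col_block_amax = grad_out.abs().reshape(M // block_size, block_size, N).amax(dim=1)

# Transposed, this matches the "grad_out_t_scales shape: (N, M//block_size)" comment.
assert col_block_amax.transpose(0, 1).shape == (N, M // block_size)
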
@@ -412,7 +419,6 @@ def backward(ctx, grad_out: torch.Tensor):
         _, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
             scale_group_offsets
         )
-
         grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
             grad_out_t_scales,
             scale_group_offsets,

@@ -438,6 +444,40 @@ def backward(ctx, grad_out: torch.Tensor):
         return grad_A, grad_B_t, None, None, None


+def _to_mxfp8_dim1_3d(
+    B: torch.Tensor,
+    block_size: int = 32,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert a 3D tensor to MXFP8 format with (block_size, 1) scaling granularity.
+    """
+    E, N, K = B.shape
+    B_reshaped = B.reshape(E * N, K)
+    B_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+        B_reshaped,
+        block_size,
+        elem_dtype=torch.float8_e4m3fn,
+        hp_dtype=B_reshaped.dtype,
+        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+        scale_calculation_mode=ScaleCalculationMode.FLOOR,
+    )
+    B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
+    B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
+    B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+    B_scales = B_scales.reshape(
+        K, E, N // block_size
+    )  # (K, E*N//block_size) -> (K, E, N//block_size)
+    B_scales = B_scales.permute(
+        1, 0, 2
+    )  # (K, E, N//block_size) -> (E, K, N//block_size)
+    B_scales = B_scales.view(torch.float8_e8m0fnu)
+
+    # TODO: Update cutlass grouped gemm to accept NT/TN/NN/TT layouts so we can avoid this conversion to column major
+    B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
+    return B_scales, B_data
+
+
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_data: torch.Tensor,
     A_scale: torch.Tensor,
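
Aside: a CPU-only sketch of the layout bookkeeping in _to_mxfp8_dim1_3d above, using an arange tensor in place of real e8m0 scales; it checks only shapes and strides, not numerics:

import torch

E, N, K, block_size = 2, 64, 16, 32

# The dim1 kernel hands back scales shaped (K, E*N//block_size); the reshape +
# permute routes each expert's scale blocks to the (E, K, N//block_size) layout
# expected by the grouped GEMM.
scales_2d = torch.arange(K * E * (N // block_size)).reshape(K, E * N // block_size)
scales_3d = scales_2d.reshape(K, E, N // block_size).permute(1, 0, 2)
assert scales_3d.shape == (E, K, N // block_size)

# B_data keeps its (E, N, K) shape but is made column-major in the last two dims,
# the layout the CUTLASS grouped GEMM currently wants per the TODO above.
B_data = torch.randn(E, N, K)
B_data_cm = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
assert torch.equal(B_data, B_data_cm) and B_data_cm.stride() != B_data.stride()
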
@@ -606,3 +646,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
     # Perform bf16 grouped GEMM using dequantized A and B.
     out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def round_up(x, y):
+    return ((x + y - 1) // y) * y
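
The new round_up helper is plain integer rounding up to a multiple; a quick self-contained sanity check (values chosen arbitrarily):

def round_up(x, y):
    return ((x + y - 1) // y) * y

assert round_up(100, 32) == 128  # pad up to the next multiple of 32
assert round_up(128, 32) == 128  # already aligned, unchanged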
