Commit 3685390

[mxfp8 moe training] use dim1 cast cuda kernel for 3d weights by reshaping to 2d
stack-info: PR: #2998, branch: danielvegamyhre/stack/67
1 parent 01bfca9 commit 3685390

File tree

1 file changed: +54 -8 lines changed

1 file changed

+54
-8
lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 54 additions & 8 deletions
@@ -358,13 +358,17 @@ def backward(ctx, grad_out: torch.Tensor):
 
         # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_data = to_mx(
+        B_scales_ref, B_data_ref = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
+        # Experiment with cuda kernel
+        B = B_t.transpose(-2, -1)
+        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,
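Since this hunk keeps the reference to_mx outputs alongside the CUDA-kernel results, a parity check is easy to drop into backward while the experiment is active. The sketch below is hypothetical and not part of the commit; it assumes to_mx also returns e8m0 scales and that both paths use FLOOR scale calculation, so data and scales should match bit-for-bit.

# Hypothetical parity check (not in the commit) against the to_mx reference.
# New B_data is (E, N, K) in column-major layout; the reference is (E, K, N).
assert torch.equal(
    B_data.transpose(-2, -1).to(torch.float32),
    B_data_ref.to(torch.float32),
)
# Compare raw e8m0 scale bytes; both tensors are (E, K, N // block_size).
assert torch.equal(
    B_scales.view(torch.uint8),
    B_scales_ref.view(torch.uint8),
)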
@@ -376,21 +380,26 @@ def backward(ctx, grad_out: torch.Tensor):
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         grad_A = torch._scaled_grouped_mm(
             grad_out_data,
-            B_data.transpose(-2, -1),
+            B_data,
             grad_out_scales_blocked,
             B_scales_blocked,
             offs=offs,
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_data shape: (N, M)
+        # grad_out_t_data shape: (M, N)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_data = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            grad_out.transpose(-2, -1).contiguous(),
+        grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            grad_out,
+            block_size,
             elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            hp_dtype=grad_out.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
         )
+        grad_out_t_data = grad_out_t_mx.qdata
+        grad_out_t_scales = grad_out_t_mx._scale_e8m0
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A_t_data shape: (K, M)
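The dim1 cast here replaces the old transpose-and-contiguous to_mx path. As a rough emulation (an assumption-laden sketch with plain amax standing in for the real e8m0 scale math, and M divisible by block_size), the scale shape noted in the comment above falls out of blocking grad_out along its rows, one block of rows per column:

import torch

M, N, block_size = 128, 64, 32
grad_out = torch.randn(M, N)
# Blocks of block_size rows per column are scaled, so the scale tensor is
# (N, M // block_size) while the quantized data keeps the (M, N) shape.
grad_out_t_scales_emulated = (
    grad_out.abs().reshape(M // block_size, block_size, N).amax(dim=1).t()
)
assert grad_out_t_scales_emulated.shape == (N, M // block_size)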
@@ -412,7 +421,6 @@ def backward(ctx, grad_out: torch.Tensor):
         _, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
             scale_group_offsets
         )
-
         grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
             grad_out_t_scales,
             scale_group_offsets,
@@ -438,6 +446,40 @@ def backward(ctx, grad_out: torch.Tensor):
         return grad_A, grad_B_t, None, None, None
 
 
+def _to_mxfp8_dim1_3d(
+    B: torch.Tensor,
+    block_size: int = 32,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert a 3D tensor to MXFP8 format with (block_size, 1) scaling granularity.
+    """
+    E, N, K = B.shape
+    B_reshaped = B.reshape(E * N, K)
+    B_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+        B_reshaped,
+        block_size,
+        elem_dtype=torch.float8_e4m3fn,
+        hp_dtype=B_reshaped.dtype,
+        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+        scale_calculation_mode=ScaleCalculationMode.FLOOR,
+    )
+    B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
+    B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
+    B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+    B_scales = B_scales.reshape(
+        K, E, N // block_size
+    )  # (K, E*N//block_size) -> (K, E, N//block_size)
+    B_scales = B_scales.permute(
+        1, 0, 2
+    )  # (K, E, N//block_size) -> (E, K, N//block_size)
+    B_scales = B_scales.view(torch.float8_e8m0fnu)
+
+    # TODO: Update cutlass grouped gemm to accept NT/TN/NN/TT layouts so we can avoid this conversion to column major
+    B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
+    return B_scales, B_data
+
+
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_data: torch.Tensor,
     A_scale: torch.Tensor,
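The reshape/permute dance in _to_mxfp8_dim1_3d can be sanity-checked at the shape level. The sketch below is a stand-in under stated assumptions (plain amax instead of the real e8m0 scale computation, toy sizes, N divisible by block_size so no block straddles two experts); it shows that flattening (E, N, K) to (E*N, K), scaling along dim1, and reshaping the (K, E*N//block_size) scales recovers exactly the per-expert (E, K, N//block_size) grouping the backward pass needs.

import torch

E, N, K, block_size = 2, 64, 16, 32
B = torch.randn(E, N, K)

# Per-expert reference: block amax along N for each (expert, column),
# shaped (E, K, N // block_size).
ref = B.abs().reshape(E, N // block_size, block_size, K).amax(dim=2).transpose(-2, -1)

# 2D path: flatten experts into rows and take the same block amax, giving the
# (K, E*N // block_size) layout the dim1 kernel emits for its scales.
B2d = B.reshape(E * N, K)
scales_2d = B2d.abs().reshape(E * N // block_size, block_size, K).amax(dim=1).t()

# Undo the flattening on the scale tensor exactly as _to_mxfp8_dim1_3d does.
scales_3d = scales_2d.reshape(K, E, N // block_size).permute(1, 0, 2)

assert torch.equal(ref, scales_3d)  # identical grouping, identical maxima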
@@ -606,3 +648,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
     # Perform bf16 grouped GEMM using dequantized A and B.
     out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def round_up(x, y):
+    return ((x + y - 1) // y) * y
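round_up is a small integer-alignment helper; the values below are illustrative only (not from the commit), e.g. padding a size up to the nearest multiple of an alignment:

assert round_up(300, 128) == 384
assert round_up(256, 128) == 256
assert round_up(1, 32) == 32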
