
Commit 8299b35

[mxfp8 moe training] use dim1 cast cuda kernel for 3d weights by reshaping to 2d
1 parent 01bfca9 · commit 8299b35

File tree: 1 file changed (+31 −2 lines)

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 31 additions & 2 deletions
@@ -358,13 +358,38 @@ def backward(ctx, grad_out: torch.Tensor):
 
         # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_data = to_mx(
+        B_scales_ref, B_data_ref = to_mx(
            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
            B_t.contiguous(),
            elem_dtype=torch.float8_e4m3fn,
            block_size=block_size,
        )
 
+        # Experiment with cuda kernel
+        B = B_t.transpose(-2, -1)
+        E, N, K = B.shape
+        B_reshaped = B.reshape(E * N, K)
+        B_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            B_reshaped,
+            block_size,
+            elem_dtype=torch.float8_e4m3fn,
+            hp_dtype=B_reshaped.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+        )
+        B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
+        B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
+        B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+        B_scales = B_scales.reshape(
+            K, E, N // block_size
+        )  # (K, E*N//block_size) -> (K, E, N//block_size)
+        B_scales = B_scales.permute(
+            1, 0, 2
+        )  # (K, E, N//block_size) -> (E, K, N//block_size)
+        B_scales = B_scales.view(torch.float8_e8m0fnu)
+        B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
+
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,
@@ -376,7 +401,7 @@ def backward(ctx, grad_out: torch.Tensor):
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         grad_A = torch._scaled_grouped_mm(
             grad_out_data,
-            B_data.transpose(-2, -1),
+            B_data,
             grad_out_scales_blocked,
             B_scales_blocked,
             offs=offs,
@@ -606,3 +631,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
     # Perform bf16 grouped GEMM using dequantized A and B.
     out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def round_up(x, y):
+    return ((x + y - 1) // y) * y
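
Note (not part of the commit): the idea in the new path is that the (E, N, K) weight tensor can be flattened to (E*N, K) so the existing 2D dim1 cast CUDA kernel can be reused, with the resulting scales reshaped back into the per-expert (E, K, N//block_size) layout. This is only valid when block_size divides N, so that no scaling block straddles an expert boundary. Below is a minimal pure-PyTorch sketch of that shape bookkeeping under that assumption; blockwise_amax_dim0 is a hypothetical stand-in for the blockwise scale computation and is not the real _to_mxfp8_dim1_kernel_wrapper.

import torch

E, N, K, block_size = 4, 64, 32, 32

B = torch.randn(E, N, K)  # expert weights in (E, N, K) layout

def blockwise_amax_dim0(x2d: torch.Tensor, bs: int) -> torch.Tensor:
    # Per-block amax down dim 0 for every column: (M, K) -> (K, M // bs),
    # matching the (K, E*N//block_size) scale shape seen in the diff.
    M, K = x2d.shape
    return x2d.abs().reshape(M // bs, bs, K).amax(dim=1).t()

# Flattened path: one 2D pass over (E*N, K), then reshape/permute the scales,
# mirroring how the commit handles _scale_e8m0.
flat = blockwise_amax_dim0(B.reshape(E * N, K), block_size)   # (K, E*N//block_size)
flat = flat.reshape(K, E, N // block_size).permute(1, 0, 2)   # (E, K, N//block_size)

# Reference path: one 2D pass per expert, stacked back into 3D.
ref = torch.stack([blockwise_amax_dim0(B[e], block_size) for e in range(E)])

# Identical as long as block_size divides N (no block crosses an expert boundary).
torch.testing.assert_close(flat, ref)

Separately, the round_up helper appended at the end of the file rounds x up to the nearest multiple of y (e.g. round_up(10, 4) == 12); it is not referenced elsewhere in the hunks shown here.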

0 commit comments
