@@ -58,15 +58,13 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        logger.debug("Using fp8 rowwise for _scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
             offs,
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.debug("Using mxfp8 for _scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,
@@ -358,13 +356,17 @@ def backward(ctx, grad_out: torch.Tensor):

         # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_data = to_mx(
+        B_scales_ref, B_data_ref = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

+        # Experiment with cuda kernel
+        B = B_t.transpose(-2, -1)
+        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,
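Since the hunk above keeps the to_mx reference cast (B_scales_ref, B_data_ref) alive next to the experimental _to_mxfp8_dim1_3d path (B_scales, B_data), a throwaway parity check is easy to drop in while debugging. The sketch below is only an illustration: _check_b_cast_parity is a hypothetical helper, and exact bit-equality between the two cast paths is an assumption rather than a guarantee.

# Hypothetical debug-only parity check for the two B casts kept side by side
# above. Bit-identity between the reference and kernel paths is assumed here
# and may not hold for every input.
def _check_b_cast_parity(B_data_ref, B_scales_ref, B_data, B_scales):
    # The kernel path returns (E, N, K) column-major data; transpose it back
    # to the reference (E, K, N) layout before comparing.
    kernel_data = B_data.transpose(-2, -1).contiguous()
    assert kernel_data.shape == B_data_ref.shape
    assert B_scales.shape == B_scales_ref.shape  # (E, K, N//block_size)
    # fp8 / e8m0 tensors are easiest to compare as raw uint8 bits.
    torch.testing.assert_close(
        kernel_data.view(torch.uint8),
        B_data_ref.view(torch.uint8),
        atol=0,
        rtol=0,
    )
    torch.testing.assert_close(
        B_scales.contiguous().view(torch.uint8),
        B_scales_ref.contiguous().view(torch.uint8),
        atol=0,
        rtol=0,
    )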
@@ -376,21 +378,26 @@ def backward(ctx, grad_out: torch.Tensor):
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         grad_A = torch._scaled_grouped_mm(
             grad_out_data,
-            B_data.transpose(-2, -1),
+            B_data,
             grad_out_scales_blocked,
             B_scales_blocked,
             offs=offs,
             out_dtype=out_dtype,
         )

-        # grad_out_t_data shape: (N, M)
+        # grad_out_t_data shape: (M, N)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_data = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            grad_out.transpose(-2, -1).contiguous(),
+        grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            grad_out,
+            block_size,
             elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            hp_dtype=grad_out.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
         )
+        grad_out_t_data = grad_out_t_mx.qdata
+        grad_out_t_scales = grad_out_t_mx._scale_e8m0

         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A_t_data shape: (K, M)
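For readers new to the "dim1" cast applied to grad_out above: it keeps the (M, N) data shape and produces one power-of-two e8m0 scale per (block_size, 1) column block, hence scales of shape (N, M//block_size). The snippet below is a rough pure-PyTorch emulation of that granularity only; emulate_mxfp8_dim1_cast is a made-up name, the scales are left in float32 for readability, and the CUDA kernel's exact FLOOR rounding and saturation behavior may differ.

import torch

def emulate_mxfp8_dim1_cast(x: torch.Tensor, block_size: int = 32):
    # One scale per (block_size, 1) column block; quantized data keeps x's shape.
    M, N = x.shape
    assert M % block_size == 0
    blocks = x.reshape(M // block_size, block_size, N)      # group rows into blocks
    amax = blocks.abs().amax(dim=1, keepdim=True).float()   # (M//bs, 1, N)
    # Floor-style power-of-two scale so each block max lands near the e4m3 max (448).
    exp = torch.floor(torch.log2(amax.clamp(min=2.0**-127))) - 8
    scale = torch.exp2(exp)                                  # (M//bs, 1, N)
    data = (blocks / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return scale.squeeze(1).transpose(0, 1), data.reshape(M, N)  # (N, M//bs), (M, N)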
@@ -412,7 +419,6 @@ def backward(ctx, grad_out: torch.Tensor):
         _, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
             scale_group_offsets
         )
-
         grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
             grad_out_t_scales,
             scale_group_offsets,
@@ -438,6 +444,40 @@ def backward(ctx, grad_out: torch.Tensor):
         return grad_A, grad_B_t, None, None, None


+def _to_mxfp8_dim1_3d(
+    B: torch.Tensor,
+    block_size: int = 32,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert a 3D tensor to MXFP8 format with (block_size, 1) scaling granularity.
+    """
+    E, N, K = B.shape
+    B_reshaped = B.reshape(E * N, K)
+    B_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+        B_reshaped,
+        block_size,
+        elem_dtype=torch.float8_e4m3fn,
+        hp_dtype=B_reshaped.dtype,
+        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+        scale_calculation_mode=ScaleCalculationMode.FLOOR,
+    )
+    B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
+    B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
+    B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+    B_scales = B_scales.reshape(
+        K, E, N // block_size
+    )  # (K, E*N//block_size) -> (K, E, N//block_size)
+    B_scales = B_scales.permute(
+        1, 0, 2
+    )  # (K, E, N//block_size) -> (E, K, N//block_size)
+    B_scales = B_scales.view(torch.float8_e8m0fnu)
+
+    # TODO: Update cutlass grouped gemm to accept NT/TN/NN/TT layouts so we can avoid this conversion to column major
+    B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
+    return B_scales, B_data
+
+
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_data: torch.Tensor,
     A_scale: torch.Tensor,
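One detail of the _to_mxfp8_dim1_3d helper above worth calling out: the final transpose(-2, -1).contiguous().transpose(-2, -1) leaves the logical (E, N, K) shape untouched but rewrites memory so that each expert's (N, K) slice is column-major, which is what the TODO about cutlass NT/TN/NN/TT layouts refers to. A minimal stride sketch (sizes are arbitrary):

import torch

E, N, K = 2, 8, 16
x = torch.zeros(E, N, K)
col_major = x.transpose(-2, -1).contiguous().transpose(-2, -1)

assert col_major.shape == (E, N, K)
assert x.stride() == (N * K, K, 1)          # row-major within each expert
assert col_major.stride() == (N * K, 1, N)  # column-major within each expert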
@@ -606,3 +646,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
     # Perform bf16 grouped GEMM using dequantized A and B.
     out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def round_up(x, y):
+    return ((x + y - 1) // y) * y
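The round_up helper appended at the end is the usual round-up-to-a-multiple trick, presumably for padding sizes to block boundaries; a quick self-contained check (the function body is copied from the hunk above so the snippet runs on its own):

def round_up(x, y):
    return ((x + y - 1) // y) * y

assert round_up(130, 32) == 160  # (130 + 31) // 32 == 5, and 5 * 32 == 160
assert round_up(128, 32) == 128  # exact multiples are unchanged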