@@ -358,13 +358,38 @@ def backward(ctx, grad_out: torch.Tensor):

         # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_data = to_mx(
+        B_scales_ref, B_data_ref = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

+        # Experiment with cuda kernel
+        B = B_t.transpose(-2, -1)
+        E, N, K = B.shape
+        B_reshaped = B.reshape(E * N, K)
+        B_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            B_reshaped,
+            block_size,
+            elem_dtype=torch.float8_e4m3fn,
+            hp_dtype=B_reshaped.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+        )
+        B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
+        B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
+        B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+        B_scales = B_scales.reshape(
+            K, E, N // block_size
+        )  # (K, E*N//block_size) -> (K, E, N//block_size)
+        B_scales = B_scales.permute(
+            1, 0, 2
+        )  # (K, E, N//block_size) -> (E, K, N//block_size)
+        B_scales = B_scales.view(torch.float8_e8m0fnu)
+        B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
+
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,
@@ -376,7 +401,7 @@ def backward(ctx, grad_out: torch.Tensor):
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         grad_A = torch._scaled_grouped_mm(
             grad_out_data,
-            B_data.transpose(-2, -1),
+            B_data,
             grad_out_scales_blocked,
             B_scales_blocked,
             offs=offs,
@@ -606,3 +631,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
     # Perform bf16 grouped GEMM using dequantized A and B.
     out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def round_up(x, y):
+    return ((x + y - 1) // y) * y
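As a standalone illustration (not part of the patch), the sketch below checks the scale-layout bookkeeping the new CUDA-kernel path performs: the dim1 cast emits scales of shape (K, E*N//block_size), which are then reshaped and permuted into the (E, K, N//block_size) layout used by the grouped GEMM. The sizes and the dummy uint8 tensor are made up for the example; only the shape manipulation mirrors the patch.

import torch

# Toy sizes chosen only for illustration; any E, N, K with N % block_size == 0 work.
E, N, K, block_size = 4, 64, 32, 32

# Stand-in for B_t_mx._scale_e8m0 viewed as uint8: one scale per block of
# `block_size` elements, laid out as (K, E*N//block_size) by the dim1 cast.
scales = torch.randint(0, 255, (K, E * N // block_size), dtype=torch.uint8)

# Same bookkeeping as the patch:
# (K, E*N//bs) -> (K, E, N//bs) -> permute -> (E, K, N//bs)
scales = scales.reshape(K, E, N // block_size).permute(1, 0, 2)
assert scales.shape == (E, K, N // block_size)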