@@ -358,13 +358,17 @@ def backward(ctx, grad_out: torch.Tensor):

         # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_data = to_mx(
+        B_scales_ref, B_data_ref = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

+        # Experiment with cuda kernel
+        B = B_t.transpose(-2, -1)
+        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,
@@ -376,21 +380,26 @@ def backward(ctx, grad_out: torch.Tensor):
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         grad_A = torch._scaled_grouped_mm(
             grad_out_data,
-            B_data.transpose(-2, -1),
+            B_data,
             grad_out_scales_blocked,
             B_scales_blocked,
             offs=offs,
             out_dtype=out_dtype,
         )

-        # grad_out_t_data shape: (N, M)
+        # grad_out_t_data shape: (M, N)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_data = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            grad_out.transpose(-2, -1).contiguous(),
+        grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            grad_out,
+            block_size,
             elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            hp_dtype=grad_out.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
         )
+        grad_out_t_data = grad_out_t_mx.qdata
+        grad_out_t_scales = grad_out_t_mx._scale_e8m0

         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A_t_data shape: (K, M)
@@ -412,7 +421,6 @@ def backward(ctx, grad_out: torch.Tensor):
         _, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
             scale_group_offsets
         )
-
         grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
             grad_out_t_scales,
             scale_group_offsets,
@@ -438,6 +446,40 @@ def backward(ctx, grad_out: torch.Tensor):
         return grad_A, grad_B_t, None, None, None


+def _to_mxfp8_dim1_3d(
+    B: torch.Tensor,
+    block_size: int = 32,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert a 3D tensor to MXFP8 format with (block_size, 1) scaling granularity.
+    """
+    E, N, K = B.shape
+    B_reshaped = B.reshape(E * N, K)
+    B_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+        B_reshaped,
+        block_size,
+        elem_dtype=torch.float8_e4m3fn,
+        hp_dtype=B_reshaped.dtype,
+        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+        scale_calculation_mode=ScaleCalculationMode.FLOOR,
+    )
+    B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
+    B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
+    B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+    B_scales = B_scales.reshape(
+        K, E, N // block_size
+    )  # (K, E*N//block_size) -> (K, E, N//block_size)
+    B_scales = B_scales.permute(
+        1, 0, 2
+    )  # (K, E, N//block_size) -> (E, K, N//block_size)
+    B_scales = B_scales.view(torch.float8_e8m0fnu)
+
+    # TODO: Update cutlass grouped gemm to accept NT/TN/NN/TT layouts so we can avoid this conversion to column major
+    B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
+    return B_scales, B_data
+
+
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_data: torch.Tensor,
     A_scale: torch.Tensor,
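
For reference, the reshape/permute bookkeeping added in `_to_mxfp8_dim1_3d` can be sanity-checked in isolation. The sketch below is only a shape check, not the kernel: `fake_dim1_cast` is a hypothetical stand-in that mimics the output shapes assumed for `_to_mxfp8_dim1_kernel_wrapper` on a (E*N, K) input (quantized data of shape (K, E*N), scales of shape (K, E*N // block_size)), so only the layout logic above is exercised.

    import torch

    def fake_dim1_cast(x_2d: torch.Tensor, block_size: int):
        # Hypothetical stand-in: returns dummy tensors with the shapes the
        # dim1 cast kernel is assumed to produce for a (rows, cols) input.
        rows, cols = x_2d.shape  # (E*N, K)
        qdata = torch.zeros(cols, rows, dtype=torch.uint8)                 # (K, E*N)
        scales = torch.zeros(cols, rows // block_size, dtype=torch.uint8)  # (K, E*N//block_size)
        return qdata, scales

    E, N, K, block_size = 2, 64, 96, 32
    B = torch.randn(E, N, K)
    qdata, scales = fake_dim1_cast(B.reshape(E * N, K), block_size)

    # Same bookkeeping as _to_mxfp8_dim1_3d:
    B_data = qdata.t().reshape(E, N, K)                                 # (E*N, K) -> (E, N, K)
    B_scales = scales.reshape(K, E, N // block_size).permute(1, 0, 2)   # -> (E, K, N//block_size)
    assert B_data.shape == (E, N, K)
    assert B_scales.shape == (E, K, N // block_size)
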
@@ -606,3 +648,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
     # Perform bf16 grouped GEMM using dequantized A and B.
     out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def round_up(x, y):
+    return ((x + y - 1) // y) * y
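
The appended `round_up` helper is plain round-to-next-multiple arithmetic; a quick check of the formula (its intended use, e.g. padding scale dimensions to a block multiple, is an assumption and is not stated in the diff):

    def round_up(x, y):
        return ((x + y - 1) // y) * y

    assert round_up(40, 32) == 64   # rounds up to the next multiple of 32
    assert round_up(64, 32) == 64   # already aligned, unchanged
    assert round_up(1, 128) == 128
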