@@ -1454,6 +1454,57 @@ def _(scale_tensor):
         padded_cols = n_col_blocks * 4
 
         return scale_tensor.new_empty((padded_rows, padded_cols))
+
+    @triton.jit
+    def fp32_cast_to_fp4x2_triton_kernel(
+        x_ptr,
+        q_ptr,
+        stride_xm,
+        stride_xn,
+        M,
+        N,
+    ):
+        pid_m = tl.program_id(1)
+        pid_n = tl.program_id(0)
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+        mask = None
+        other = None
+        x = tl.load(
+            x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
+        )  # [128, 64]
+        x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+        # Convert to FP4
+        x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+        mask = (offs_m < M) & (offs_n < N // 2)
+        tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
+
+    def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
+        """
+        Input: a float32 tensor with shape (M, N)
+        Output: a uint8 tensor with shape (M, N // 2), with the values being the result
+        of casting each original value to fp4_e2m1, and then packing fp4x2
+
+        TODO(future PR): optimize performance; the lowest hanging fruit is to add
+        an e8m0 scale and scale the incoming tensor inside of this kernel
+        TODO(future PR): better checks for shapes, etc
+        TODO(future PR): integrate into training/inference
+        TODO(future PR): integrate with compile, ideally allowing fusion
+        """
+        M, N = x.shape
+        xq = x.new_empty(M, N // 2, dtype=torch.uint8)
+        grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
+        fp32_cast_to_fp4x2_triton_kernel[grid](
+            x,
+            xq,
+            x.stride(0),
+            x.stride(1),
+            M,
+            N,
+        )
+        return xq.view(torch.uint8)
 else:
 
     def triton_to_mxfp8_dim1(
@@ -1475,6 +1526,9 @@ def triton_quantize_nvfp4(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
 
+    def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
 
 # MXFP8 CUDA kernel is only built on SM100+
 if is_sm_at_least_100():
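
A minimal usage sketch of the new cast (not part of the diff): it assumes torch 2.8+, Triton, and a CUDA device, and it assumes this file is torchao/prototype/mx_formats/kernels.py (as the surrounding triton_to_mxfp8_dim1 and triton_quantize_nvfp4 helpers suggest) so the function is importable from there. Since the kernel loads 128x64 tiles with mask=None, the sketch picks shapes divisible by the tile size; per the docstring TODO, shape checks are not yet enforced.

import torch

from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2

# M divisible by 128 and N divisible by 64, matching the kernel's unmasked load.
x = torch.randn(256, 128, dtype=torch.float32, device="cuda")
xq = triton_fp32_cast_to_fp4x2(x)

# Each output byte packs two fp4_e2m1 values (representable magnitudes
# 0, 0.5, 1, 1.5, 2, 3, 4, 6 plus sign), so the last dim halves.
assert xq.shape == (256, 64)
assert xq.dtype == torch.uint8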