Commit 8d05770

[wip] mx: expose a fast path for casting to fp4x2
Summary: not ready for review yet

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: ed69b9f
ghstack-comment-id: 3210931181
Pull-Request: #2832
1 parent d37dcb7 commit 8d05770
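
At a glance, the fast path is opt-in via a new keyword argument on the MX conversion entry points. A minimal usage sketch, mirroring the new test in test_mx_tensor.py below; it assumes a CUDA device supported by the kernel and that MXTensor is importable from torchao.prototype.mx_formats.mx_tensor:

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)

# existing path: f32_to_f4_unpacked + pack_uint4
mx_ref = MXTensor.to_mx(x, torch.float4_e2m1fn_x2, 32)

# new opt-in Triton fast path added by this commit
mx_fast = MXTensor.to_mx(
    x, torch.float4_e2m1fn_x2, 32, use_fp32_to_fp4_triton_kernel=True
)

# both paths are expected to produce identical packed data and scales
torch.testing.assert_close(mx_ref.qdata, mx_fast.qdata, atol=0, rtol=0)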

File tree

5 files changed: +159 -9 lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 43 additions & 1 deletion

@@ -59,8 +59,15 @@ def to_mx_dim0_reference(
     block_size,
     scaling_mode=ScaleCalculationMode.FLOOR,
     target_dtype=torch.float8_e4m3fn,
+    use_fp32_to_fp4_triton_kernel=False,
 ):
-    scale_d0, data_d0 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode)
+    scale_d0, data_d0 = to_mx(
+        x_hp,
+        target_dtype,
+        block_size,
+        scaling_mode=scaling_mode,
+        use_fp32_to_fp4_triton_kernel=use_fp32_to_fp4_triton_kernel,
+    )
     return data_d0, scale_d0
 
 
@@ -96,6 +103,7 @@ def run(
         "dim0_dim1",
         "dim0_mxfp8_floor",
         "dim0_mxfp4_floor",
+        "dim0_mxfp4_triton_floor",
         "dim0_mxfp8_rceil",
         "dim1_mxfp8_floor",
         "dim1_mxfp8_rceil",
@@ -204,6 +212,40 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim0_mxfp4_triton_floor":
+        to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
+        y_d0, s_d0 = to_mx_dim0_reference_c(
+            x,
+            BLOCK_SIZE,
+            target_dtype=torch.float4_e2m1fn_x2,
+            use_fp32_to_fp4_triton_kernel=True,
+        )
+
+        for _ in range(2):
+            __ = to_mx_dim0_reference_c(
+                x,
+                BLOCK_SIZE,
+                target_dtype=torch.float4_e2m1fn_x2,
+                use_fp32_to_fp4_triton_kernel=True,
+            )
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim0_reference_c(
+                x,
+                BLOCK_SIZE,
+                target_dtype=torch.float4_e2m1fn_x2,
+                use_fp32_to_fp4_triton_kernel=True,
+            ),
+            x,
+            BLOCK_SIZE,
+        )
+
+        # TODO(future PR): make to_mx return float4 directly
+        assert y_d0.dtype == torch.uint8
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim0_mxfp8_rceil":
         to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
         y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
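
For intuition on the bps value computed in the "dim0_mxfp4_triton_floor" branch above, here is the same traffic accounting written out in a standalone sketch; M, K, and the measured time are hypothetical placeholders, not benchmark defaults:

# hypothetical sizes and a placeholder measurement, for illustration only
M, K, BLOCK_SIZE = 16384, 16384, 32
bytes_per_el_bf16, bytes_per_el_fp8 = 2, 1

bytes_r = M * K * bytes_per_el_bf16      # bf16 input read: 536,870,912 bytes
packed_elems = M * K // 2                # two fp4 values packed per uint8
scale_elems = M * K // BLOCK_SIZE        # one e8m0 scale per 32-element block
bytes_w = (packed_elems + scale_elems) * bytes_per_el_fp8

time_us = 1000.0                         # placeholder kernel time in microseconds
bps = (bytes_r + bytes_w) / (time_us / 1e6)  # achieved bytes per second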

test/prototype/mx_formats/test_kernels.py

Lines changed: 26 additions & 0 deletions

@@ -561,3 +561,29 @@ def test_cuda_mx_dim1_invalid_block_size():
             scale_dim_x=1,
             scale_dim_y=invalid_block_size,
         )
+
+
+def _fp32_to_fp4_reference(
+    data_hp: torch.Tensor,
+) -> torch.Tensor:
+    data_lp = f32_to_f4_unpacked(data_hp.float())
+    data_lp = pack_uint4(data_lp)
+    return data_lp
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="requires CUDA capability 10.0 or greater",
+)
+def test_fp32_cast_to_fp4x2():
+    from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2
+
+    M, K = 16, 16
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    # make x's range be the representable range of fp4
+    x = x * 6.0
+
+    data_ref = _fp32_to_fp4_reference(x)
+    data = triton_fp32_cast_to_fp4x2(x)
+    torch.testing.assert_close(data_ref, data, atol=0, rtol=0)
+    assert data.shape == (M, K // 2)

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 16 additions & 0 deletions

@@ -96,6 +96,22 @@ def test_realistic_numerics(elem_dtype, scale_calculation_mode):
     _test_mx(data, elem_dtype, block_size, scale_calculation_mode)
 
 
+def test_fp4_triton_cast_does_not_change_numerics():
+    # TODO(before land): proper skips
+    # TODO(before land): test rank 3
+    data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+    data_mx_ref = MXTensor.to_mx(
+        data, torch.float4_e2m1fn_x2, 32, use_fp32_to_fp4_triton_kernel=False
+    )
+    data_mx = MXTensor.to_mx(
+        data, torch.float4_e2m1fn_x2, 32, use_fp32_to_fp4_triton_kernel=True
+    )
+    torch.testing.assert_close(data_mx_ref.qdata, data_mx.qdata, atol=0, rtol=0)
+    torch.testing.assert_close(
+        data_mx_ref._scale_e8m0, data_mx._scale_e8m0, atol=0, rtol=0
+    )
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
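
The "proper skips" TODO above would presumably mirror the guards already used in this test suite; a hedged sketch of what that could look like (the exact requirements, and importing is_sm_at_least_100 here as test_kernels.py does, are assumptions pending that TODO):

# hypothetical skip guards mirroring test_all_zeros and test_fp32_cast_to_fp4x2
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not is_sm_at_least_100(),
    reason="requires CUDA capability 10.0 or greater",
)
def test_fp4_triton_cast_does_not_change_numerics():
    ...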

torchao/prototype/mx_formats/kernels.py

Lines changed: 54 additions & 0 deletions

@@ -1454,6 +1454,57 @@ def _(scale_tensor):
         padded_cols = n_col_blocks * 4
 
         return scale_tensor.new_empty((padded_rows, padded_cols))
+
+    @triton.jit
+    def fp32_cast_to_fp4x2_triton_kernel(
+        x_ptr,
+        q_ptr,
+        stride_xm,
+        stride_xn,
+        M,
+        N,
+    ):
+        pid_m = tl.program_id(1)
+        pid_n = tl.program_id(0)
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+        mask = None
+        other = None
+        x = tl.load(
+            x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
+        )  # [128, 64]
+        x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+        # Convert to FP4
+        x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+        mask = (offs_m < M) & (offs_n < N // 2)
+        tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
+
+    def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
+        """
+        Input: a float32 tensor with shape (M, N)
+        Output: a uint8 tensor with shape (M, N // 2), with the values being the result
+        of casting each original value to fp4_e2m1, and then packing fp4x2
+
+        TODO(future PR): optimize performance, lowest hanging fruit is we want
+        to add an e8m0 scale and scale the incoming tensor inside of this kernel
+        TODO(future PR): better checks for shapes, etc
+        TODO(future PR): integrate into training/inference
+        TODO(future PR): integrate with compile, ideally allowing fusion
+        """
+        M, N = x.shape
+        xq = x.new_empty(M, N // 2, dtype=torch.uint8)
+        grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
+        fp32_cast_to_fp4x2_triton_kernel[grid](
+            x,
+            xq,
+            x.stride(0),
+            x.stride(1),
+            M,
+            N,
+        )
+        return xq.view(torch.uint8)
 else:
 
     def triton_to_mxfp8_dim1(
@@ -1475,6 +1526,9 @@ def triton_quantize_nvfp4(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
 
+    def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
 
 # MXFP8 CUDA kernel is only built on SM100+
 if is_sm_at_least_100():
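
For intuition, the cast-and-pack idea behind the kernel can be sketched in pure PyTorch. This is a hypothetical, illustrative reference only: it is not the Triton kernel, the nibble packing order is an assumption (the real layout is defined by convert_fp32_to_fp4_packed and pack_uint4), and rounding ties are not guaranteed to match the hardware conversion bit-exactly:

import torch

# the eight representable fp4 e2m1 magnitudes; a sign bit makes up the 4th bit
_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp32_to_fp4x2_sketch(x: torch.Tensor) -> torch.Tensor:
    # hypothetical reference: quantize each value to the nearest e2m1 value,
    # then pack two 4-bit codes per uint8 along the last dimension
    assert x.shape[-1] % 2 == 0
    sign = (x < 0).to(torch.uint8)
    mag = x.abs().float().clamp(max=6.0)                      # saturate at +/-6.0
    dist = (mag.unsqueeze(-1) - _E2M1_VALUES.to(x.device)).abs()
    code = dist.argmin(dim=-1).to(torch.uint8)                # 3-bit magnitude code
    nibble = (sign << 3) | code                               # 4-bit e2m1 encoding
    lo, hi = nibble[..., 0::2], nibble[..., 1::2]
    # packing order (even element in the low nibble) is an assumption
    return (hi << 4) | lo                                     # shape (..., N // 2)

By contrast, the Triton kernel above needs no per-element lookup: each program loads a [128, 64] fp32 tile, converts it with the packed fp32-to-fp4 helper, and stores [128, 32] packed bytes.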

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 20 additions & 8 deletions

@@ -55,6 +55,7 @@
     pack_uint6,
     triton_f6_e2m3_to_scaled_bf16,
     triton_f6_e3m2_to_scaled_bf16,
+    triton_fp32_cast_to_fp4x2,
     unpack_uint4,
 )
 from torchao.quantization.quantize_.common import (
@@ -134,6 +135,7 @@ def to_mx(
     block_size: int,
     scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
     pack_fp6: bool = False,
+    use_fp32_to_fp4_triton_kernel: bool = False,
 ):
     """
     Takes a high precision tensor and converts to MX scale and raw data, in
@@ -309,13 +311,17 @@ def to_mx(
         # need to reshape at the end to help inductor fuse things
         data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == torch.float4_e2m1fn_x2:
-        # can't reshape at the end without handling it in the packing code,
-        # punt until later since we'll need to rethink the torch.compile
-        # approach for fp4x2 in any case
-        data_lp = data_lp.reshape(orig_shape)
-        data_lp = f32_to_f4_unpacked(data_lp)
-        orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2]
-        data_lp = pack_uint4(data_lp)
+        if use_fp32_to_fp4_triton_kernel:
+            data_lp = data_lp.reshape(orig_shape)
+            data_lp = triton_fp32_cast_to_fp4x2(data_lp)
+        else:
+            # can't reshape at the end without handling it in the packing code,
+            # punt until later since we'll need to rethink the torch.compile
+            # approach for fp4x2 in any case
+            data_lp = data_lp.reshape(orig_shape)
+            data_lp = f32_to_f4_unpacked(data_lp)
+            orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2]
+            data_lp = pack_uint4(data_lp)
     else:
         raise AssertionError("unsupported")
 
@@ -583,9 +589,15 @@ def to_mx(
         gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
         pack_fp6: bool = False,
         act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None,
+        use_fp32_to_fp4_triton_kernel: bool = False,
     ):
         scale_e8m0_biased, data_lp = to_mx(
-            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
+            data_hp,
+            elem_dtype,
+            block_size,
+            scaling_mode,
+            pack_fp6,
+            use_fp32_to_fp4_triton_kernel,
         )
         if isinstance(scale_e8m0_biased, DTensor):
             assert isinstance(data_lp, DTensor), "unsupported"
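
For completeness, the lower-level functional path can be exercised directly as well; a hedged sketch in which the import path and keyword usage are assumptions based on the module above, and the output dtypes follow the benchmark's assertions:

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx

x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
scale, data = to_mx(
    x,
    torch.float4_e2m1fn_x2,
    block_size=32,
    use_fp32_to_fp4_triton_kernel=True,
)
# per the benchmark above: packed fp4x2 payload plus e8m0 scales
assert data.dtype == torch.uint8 and data.shape == (128, 64)
assert scale.dtype == torch.float8_e8m0fnu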
