Commit 0ebe1b4

[mxfp8 moe training] wrap 3d quantize tensor in custom ops and integrate it
1 parent b3b709c commit 0ebe1b4
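
The change replaces direct calls into the mxfp8_cuda extension with a torch.library custom op plus a registered fake (meta) implementation, so the 3D quantize kernel composes with torch.compile and fake-tensor tracing. Below is a minimal, self-contained sketch of that pattern, not torchao's code: the demo::quantize_3d name and the placeholder eager body are illustrative only, while the op actually registered in this commit is torchao::mxfp8_quantize_cuda_3d (see the mxfp8.py diff further down).

from typing import Tuple

import torch


# Illustrative namespace/op name; the commit registers "torchao::mxfp8_quantize_cuda_3d".
@torch.library.custom_op("demo::quantize_3d", mutates_args=())
def quantize_3d(x: torch.Tensor, block_size: int = 32) -> Tuple[torch.Tensor, torch.Tensor]:
    # Eager implementation. The real op calls the mxfp8_cuda CUDA extension here;
    # this placeholder just casts so the sketch runs on any device with a recent PyTorch.
    E, N, K = x.shape
    data = x.to(torch.float8_e4m3fn)
    scales = x.new_empty((E, N // block_size, K), dtype=torch.float8_e8m0fnu)
    return data, scales


@quantize_3d.register_fake
def _(x: torch.Tensor, block_size: int = 32) -> Tuple[torch.Tensor, torch.Tensor]:
    # Shape/dtype-only implementation used by torch.compile and fake-tensor tracing,
    # so graphs can be captured without launching the kernel.
    E, N, K = x.shape
    data = x.new_empty(x.shape, dtype=torch.float8_e4m3fn)
    scales = x.new_empty((E, N // block_size, K), dtype=torch.float8_e8m0fnu)
    return data, scales


x = torch.randn(2, 64, 32, dtype=torch.bfloat16)
data, scales = quantize_3d(x)
print(data.shape, scales.shape)  # torch.Size([2, 64, 32]) torch.Size([2, 2, 32])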

File tree: 8 files changed, +97 -20 lines

benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
 from benchmarks.utils import benchmark_cuda_function_in_microseconds
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )

benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py

Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype import mxfp8_cuda
+from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _to_mxfp8_dim1_3d,
 )

@@ -110,9 +110,9 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
     )

     # bench 3d cuda kernel
-    data_cuda_3d, scales_cuda_3d = mxfp8_cuda.quantize_3d(input_tensor)
+    data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
     time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
-        mxfp8_cuda.quantize_3d,
+        mxfp8_quantize_cuda_3d,
         input_tensor,
     )
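
For reference, a hedged sketch of how the renamed wrapper might be exercised from this benchmark; the (E, N, K) shape below is illustrative, and the call requires an SM100+ GPU with the mxfp8_cuda extension built from source (otherwise mxfp8_quantize_cuda_3d raises NotImplementedError). benchmark_cuda_function_in_microseconds is the helper already imported above.

import torch

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d

# Illustrative (E, N, K) expert-weight shape; adjust to the benchmark's configs.
input_tensor = torch.randn(8, 8192, 5120, dtype=torch.bfloat16, device="cuda")

# One eager call (also serves as warmup), then timing of the wrapped kernel.
data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
    mxfp8_quantize_cuda_3d,
    input_tensor,
)
print(f"quantize_3d: {time_cuda_3d_us:.2f} us, data {tuple(data_cuda_3d.shape)}, scales {tuple(scales_cuda_3d.shape)}")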

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_M_groups,
     triton_mx_block_rearrange_2d_M_groups,

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_per_group_3d,
     triton_mx_block_rearrange_per_group_3d,
 )

test/prototype/moe_training/test_kernels.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
 )
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_K_groups,
     compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_K_groups,

torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py renamed to torchao/prototype/moe_training/kernels/mxfp8.py

Lines changed: 79 additions & 1 deletion

@@ -1,3 +1,4 @@
+import logging
 from typing import Tuple

 import torch

@@ -7,7 +8,10 @@
 from torch.library import triton_op, wrap_triton

 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import ceil_div
+from torchao.utils import (
+    ceil_div,
+    is_sm_at_least_100,
+)


 def torch_to_blocked_2d_M_groups(

@@ -620,3 +624,77 @@ def _dest_indices_for_block(
     # Flatten
     dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
     return dest_indices_flat
+
+
+mxfp8_cuda_extension_available = False
+if is_sm_at_least_100():
+    try:
+        # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
+        # currently our CI runners are not SM100+, so the user needs to build
+        # from source.
+        # TODO(#2932): improve this
+        from torchao.prototype import mxfp8_cuda
+
+        mxfp8_cuda_extension_available = True
+    except ImportError:
+        logging.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+
+if mxfp8_cuda_extension_available:
+    # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
+    # Currently we have to use an arbitrary string because custom ops don't support enum
+    # params.
+    @torch.library.custom_op("torchao::mxfp8_quantize_cuda_3d", mutates_args=())
+    def mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantizes a 3D tensor of shape (E,N,K) to MXFP8 format, scaling along N.
+
+        Args:
+            x (torch.Tensor): Input tensor to be quantized.
+            block_size (int, optional): Block size for quantization. Defaults to 32.
+            scaling_mode (str, optional): Scaling mode for quantization. Defaults to "floor".
+
+        Returns:
+            torch.Tensor: quantized tensor
+            torch.Tensor: scales tensor
+        """
+        assert x.ndim == 3, "Input tensor must be 3D"
+        assert x.dtype in (torch.float32, torch.bfloat16), (
+            "Input tensor must be float32 or bfloat16"
+        )
+        q_data, scales = mxfp8_cuda.quantize_3d(
+            x, scale_dim_n=block_size, scaling_mode=scaling_mode
+        )
+        return q_data, scales
+
+    @mxfp8_quantize_cuda_3d.register_fake
+    def _fake_mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert x.ndim == 3, "Input tensor must be 3D"
+        assert x.dtype in (torch.float32, torch.bfloat16), (
+            "Input tensor must be float32 or bfloat16"
+        )
+        E, N, K = x.shape
+        # Quantized tensor is in column major layout
+        q_data = x.new_empty(x.shape, dtype=torch.float8_e4m3fn).as_strided(
+            x.shape, (N * K, 1, N)
+        )
+        scales = x.new_empty((E, N // block_size, K), dtype=torch.float8_e8m0fnu)
+        return q_data, scales
+
+else:
+
+    def mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError(
+            "mxfp8_quantize_cuda_3d is not implemented on this device"
+        )
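
As a usage sketch of the new wrapper (assumptions: an SM100+ GPU with the extension built from source, and an illustrative shape), the eager path returns FP8 data (column-major per the fake's metadata) plus e8m0 scales, and the registered fake lets the op be captured by torch.compile:

import torch

from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d

E, N, K = 4, 256, 128  # illustrative expert-weight dims
x = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

# Eager call into the CUDA extension (raises NotImplementedError without it).
q_data, scales = mxfp8_quantize_cuda_3d(x, block_size=32)
print(q_data.dtype, q_data.stride())  # torch.float8_e4m3fn, column-major strides per the fake's metadata
print(scales.dtype, scales.shape)     # torch.float8_e8m0fnu, (E, N // 32, K)

# Because a fake (meta) implementation is registered, the custom op can be traced.
@torch.compile(fullgraph=True)
def quantize(t):
    return mxfp8_quantize_cuda_3d(t, block_size=32)

q_data_c, scales_c = quantize(x)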

torchao/prototype/moe_training/kernels/mxfp8_gemms.py

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@

 import torch

-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 10 additions & 11 deletions

@@ -16,9 +16,10 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_rowwise_3d_transpose_rhs,
 )
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_K_groups,
     compute_blocked_scale_offsets_for_M_groups,
+    mxfp8_quantize_cuda_3d,
     triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,

@@ -354,18 +355,15 @@ def backward(ctx, grad_out: torch.Tensor):
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )

-        # B_data shape: (E, K, N)
-        # B_scale shape: (E, K, N//block_size)
-        B_scales_ref, B_data_ref = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            B_t.contiguous(),
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+        # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
+        # (E, K, N) -> (E, N, K)
+        B = B_t.transpose(-2, -1)
+        B_data, B_scales = mxfp8_quantize_cuda_3d(
+            B._data if hasattr(B, "_data") else B, block_size=block_size
         )

-        # Experiment with cuda kernel
-        B = B_t.transpose(-2, -1)
-        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+        # (E, N//block_size, K) -> (E, K, N//block_size)
+        B_scales = B_scales.transpose(-2, -1)

         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(

@@ -400,6 +398,7 @@ def backward(ctx, grad_out: torch.Tensor):
         grad_out_t_scales = grad_out_t_mx._scale_e8m0

         # Transpose A so we can scale along the M dimension, then un-transpose.
+        # A shape: (M, K)
         # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
         A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
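
To make the layout bookkeeping in the new backward path explicit, here is a shape-only sketch (dims are illustrative; no GPU or extension is needed since it never calls the kernel). It mirrors the diff above rather than adding behavior: B_t is (E, K, N), the kernel scales along N, and the scales come back as (E, N // block_size, K) before being transposed to the (E, K, N // block_size) layout the removed reference path documented.

import torch

E, K, N = 4, 128, 256  # illustrative expert-weight dims
block_size = 32

B_t = torch.randn(E, K, N, dtype=torch.bfloat16)

# View B_t as (E, N, K) so quantization scales along N, the contraction
# dimension of the next grouped gemm.
B = B_t.transpose(-2, -1)
assert B.shape == (E, N, K)

# mxfp8_quantize_cuda_3d returns scales shaped (E, N // block_size, K);
# here we only allocate a stand-in tensor to show the shapes.
scales = torch.empty(E, N // block_size, K, dtype=torch.float8_e8m0fnu)

# Transpose back to (E, K, N // block_size), matching the B_scale layout noted
# in the removed to_mx reference path for the downstream blocked-scale rearrange.
B_scales = scales.transpose(-2, -1)
assert B_scales.shape == (E, K, N // block_size)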
