
Commit 92144cb

[mxfp8 moe training] wrap 3d quantize tensor in custom ops and integrate it
stack-info: PR: #3004, branch: danielvegamyhre/stack/71
1 parent 6403a25 commit 92144cb

8 files changed, +103 -21 lines changed

8 files changed

+103
-21
lines changed

benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from benchmarks.utils import benchmark_cuda_function_in_microseconds
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )

benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype import mxfp8_cuda
+from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _to_mxfp8_dim1_3d,
 )
@@ -110,9 +110,9 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
     )

     # bench 3d cuda kernel
-    data_cuda_3d, scales_cuda_3d = mxfp8_cuda.quantize_3d(input_tensor)
+    data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
     time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
-        mxfp8_cuda.quantize_3d,
+        mxfp8_quantize_cuda_3d,
         input_tensor,
     )

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_M_groups,
     triton_mx_block_rearrange_2d_M_groups,

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_per_group_3d,
     triton_mx_block_rearrange_per_group_3d,
 )

test/prototype/moe_training/test_kernels.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
 )
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_K_groups,
     compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_K_groups,

torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py renamed to torchao/prototype/moe_training/kernels/mxfp8.py

Lines changed: 79 additions & 1 deletion
@@ -1,3 +1,4 @@
+import logging
 from typing import Tuple

 import torch
@@ -7,7 +8,10 @@
 from torch.library import triton_op, wrap_triton

 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import ceil_div
+from torchao.utils import (
+    ceil_div,
+    is_sm_at_least_100,
+)


 def torch_to_blocked_2d_M_groups(
@@ -645,3 +649,77 @@ def _dest_indices_for_block(
     # Flatten
     dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
     return dest_indices_flat
+
+
+mxfp8_cuda_extension_available = False
+if is_sm_at_least_100():
+    try:
+        # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
+        # currently our CI runners are not SM100+, so the user needs to build
+        # from source.
+        # TODO(#2932): improve this
+        from torchao.prototype import mxfp8_cuda
+
+        mxfp8_cuda_extension_available = True
+    except ImportError:
+        logging.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+
+if mxfp8_cuda_extension_available:
+    # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
+    # Currently we have to use an arbitrary string because custom ops don't support enum
+    # params.
+    @torch.library.custom_op("torchao::mxfp8_quantize_cuda_3d", mutates_args=())
+    def mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantizes a 3D tensor of shape (E,N,K) to MXFP8 format, scaling along N.
+
+        Args:
+            x (torch.Tensor): Input tensor to be quantized.
+            block_size (int, optional): Block size for quantization. Defaults to 32.
+            scaling_mode (str, optional): Scaling mode for quantization. Defaults to "floor".
+
+        Returns:
+            torch.Tensor: quantized tensor
+            torch.Tensor: scales tensor
+        """
+        assert x.ndim == 3, "Input tensor must be 3D"
+        assert x.dtype in (torch.float32, torch.bfloat16), (
+            "Input tensor must be float32 or bfloat16"
+        )
+        q_data, scales = mxfp8_cuda.quantize_3d(
+            x, scale_dim_n=block_size, scaling_mode=scaling_mode
+        )
+        return q_data, scales
+
+    @mxfp8_quantize_cuda_3d.register_fake
+    def _fake_mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert x.ndim == 3, "Input tensor must be 3D"
+        assert x.dtype in (torch.float32, torch.bfloat16), (
+            "Input tensor must be float32 or bfloat16"
+        )
+        E, N, K = x.shape
+        # Quantized tensor is in column major layout
+        q_data = x.new_empty(x.shape, dtype=torch.float8_e4m3fn).as_strided(
+            x.shape, (N * K, 1, N)
+        )
+        scales = x.new_empty((E, N // block_size, K), dtype=torch.float8_e8m0fnu)
+        return q_data, scales
+
+else:
+
+    def mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError(
+            "mxfp8_quantize_cuda_3d is not implemented on this device"
+        )
torchao/prototype/moe_training/kernels/mxfp8_gemms.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 import torch

-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 16 additions & 12 deletions
@@ -16,9 +16,10 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_rowwise_3d_transpose_rhs,
 )
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_K_groups,
     compute_blocked_scale_offsets_for_M_groups,
+    mxfp8_quantize_cuda_3d,
     triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
@@ -354,18 +355,20 @@ def backward(ctx, grad_out: torch.Tensor):
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )

-        # B_data shape: (E, K, N)
-        # B_scale shape: (E, K, N//block_size)
-        B_scales_ref, B_data_ref = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            B_t.contiguous(),
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-        )
-
-        # Experiment with cuda kernel
+        # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
+        # (E, K, N) -> (E, N, K)
         B = B_t.transpose(-2, -1)
-        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+        E, N, K = B.shape
+
+        # mxfp8_quantize_cuda_3d is only faster for E > 8
+        if E > 8:
+            B_data, B_scales = mxfp8_quantize_cuda_3d(
+                B._data if hasattr(B, "_data") else B, block_size=block_size
+            )
+            # (E, N//block_size, K) -> (E, K, N//block_size)
+            B_scales = B_scales.transpose(-2, -1)
+        else:
+            B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)

         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
@@ -400,6 +403,7 @@ def backward(ctx, grad_out: torch.Tensor):
         grad_out_t_scales = grad_out_t_mx._scale_e8m0

         # Transpose A so we can scale along the M dimension, then un-transpose.
+        # A shape: (M, K)
         # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
         A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
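
Note: in the backward pass above, the transposed 3D expert weights (E, N, K) are quantized along N with the new CUDA custom op only when E > 8; otherwise the existing _to_mxfp8_dim1_3d path is used. The sketch below shows the expected output shapes, inferred from the fake implementation in this commit; the sizes are illustrative, and actually running it requires an SM100+ GPU with the mxfp8_cuda extension built from source.

import torch
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d

# Illustrative sizes (E experts, N x K per-expert weight after transpose).
E, N, K, block_size = 16, 8192, 5120, 32
B = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

B_data, B_scales = mxfp8_quantize_cuda_3d(B, block_size=block_size)
assert B_data.shape == (E, N, K)                  # fp8 e4m3 data, column-major per expert
assert B_scales.shape == (E, N // block_size, K)  # e8m0 scales along N

# As in the backward above, scales are transposed before the blocked-layout rearrangement.
B_scales_t = B_scales.transpose(-2, -1)           # (E, K, N // block_size)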
