diff --git a/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py b/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
index 8caadc4fe3..9c49033a9d 100644
--- a/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
+++ b/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
@@ -17,7 +17,7 @@
 from benchmarks.utils import benchmark_cuda_function_in_microseconds
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )
diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py b/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
index bd630b2b82..b57ca81d4c 100644
--- a/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype import mxfp8_cuda
+from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _to_mxfp8_dim1_3d,
 )
@@ -110,9 +110,9 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
     )

     # bench 3d cuda kernel
-    data_cuda_3d, scales_cuda_3d = mxfp8_cuda.quantize_3d(input_tensor)
+    data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
     time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
-        mxfp8_cuda.quantize_3d,
+        mxfp8_quantize_cuda_3d,
         input_tensor,
     )
diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py
index 9013516740..b02124b782 100644
--- a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py
@@ -14,7 +14,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_M_groups,
     triton_mx_block_rearrange_2d_M_groups,
diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py
index 19fbdb3194..296270fe62 100644
--- a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py
@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_per_group_3d,
     triton_mx_block_rearrange_per_group_3d,
 )
diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py
index 6a578e7b52..e89b5a6043 100644
--- a/test/prototype/moe_training/test_kernels.py
+++ b/test/prototype/moe_training/test_kernels.py
@@ -20,7 +20,7 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
 )
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_K_groups,
     compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_K_groups,
diff --git a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py b/torchao/prototype/moe_training/kernels/mxfp8.py
similarity index 89%
rename from torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py
rename to torchao/prototype/moe_training/kernels/mxfp8.py
index d08fa9e371..353688f185 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Tuple

 import torch
@@ -7,7 +8,10 @@
 from torch.library import triton_op, wrap_triton

 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import ceil_div
+from torchao.utils import (
+    ceil_div,
+    is_sm_at_least_100,
+)


 def torch_to_blocked_2d_M_groups(
@@ -645,3 +649,77 @@ def _dest_indices_for_block(
     # Flatten
     dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
     return dest_indices_flat
+
+
+mxfp8_cuda_extension_available = False
+if is_sm_at_least_100():
+    try:
+        # The MXFP8 CUDA kernel is only built on SM100+. Furthermore, our CI
+        # runners are currently not SM100+, so the user needs to build from
+        # source.
+        # TODO(#2932): improve this
+        from torchao.prototype import mxfp8_cuda
+
+        mxfp8_cuda_extension_available = True
+    except ImportError:
+        logging.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+
+if mxfp8_cuda_extension_available:
+    # TODO: Make `scaling_mode` a choice (enum-like) rather than an arbitrary
+    # string. Currently we have to use an arbitrary string because custom ops
+    # don't support enum params.
+    @torch.library.custom_op("torchao::mxfp8_quantize_cuda_3d", mutates_args=())
+    def mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantizes a 3D tensor of shape (E, N, K) to MXFP8 format, scaling along N.
+
+        Args:
+            x (torch.Tensor): Input tensor to be quantized.
+            block_size (int, optional): Block size for quantization. Defaults to 32.
+            scaling_mode (str, optional): Scaling mode for quantization. Defaults to "floor".
+
+        Returns:
+            torch.Tensor: quantized tensor
+            torch.Tensor: scales tensor
+        """
+        assert x.ndim == 3, "Input tensor must be 3D"
+        assert x.dtype in (torch.float32, torch.bfloat16), (
+            "Input tensor must be float32 or bfloat16"
+        )
+        q_data, scales = mxfp8_cuda.quantize_3d(
+            x, scale_dim_n=block_size, scaling_mode=scaling_mode
+        )
+        return q_data, scales
+
+    @mxfp8_quantize_cuda_3d.register_fake
+    def _fake_mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert x.ndim == 3, "Input tensor must be 3D"
+        assert x.dtype in (torch.float32, torch.bfloat16), (
+            "Input tensor must be float32 or bfloat16"
+        )
+        E, N, K = x.shape
+        # Quantized tensor is in column-major layout
+        q_data = x.new_empty(x.shape, dtype=torch.float8_e4m3fn).as_strided(
+            x.shape, (N * K, 1, N)
+        )
+        scales = x.new_empty((E, N // block_size, K), dtype=torch.float8_e8m0fnu)
+        return q_data, scales
+
+else:
+
+    def mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError(
+            "mxfp8_quantize_cuda_3d is not implemented on this device"
+        )
diff --git a/torchao/prototype/moe_training/kernels/mxfp8_gemms.py b/torchao/prototype/moe_training/kernels/mxfp8_gemms.py
index 4f419f4c6f..fdbc518afa 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8_gemms.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8_gemms.py
@@ -2,7 +2,7 @@

 import torch

-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 24c1e6b60d..ab80104d3c 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -16,9 +16,10 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_rowwise_3d_transpose_rhs,
 )
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_K_groups,
     compute_blocked_scale_offsets_for_M_groups,
+    mxfp8_quantize_cuda_3d,
     triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
@@ -354,18 +355,20 @@ def backward(ctx, grad_out: torch.Tensor):
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )

-        # B_data shape: (E, K, N)
-        # B_scale shape: (E, K, N//block_size)
-        B_scales_ref, B_data_ref = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            B_t.contiguous(),
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-        )
-
-        # Experiment with cuda kernel
+        # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
+        # (E, K, N) -> (E, N, K)
         B = B_t.transpose(-2, -1)
-        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+        E, N, K = B.shape
+
+        # mxfp8_quantize_cuda_3d is only faster for E > 8
+        if E > 8:
+            B_data, B_scales = mxfp8_quantize_cuda_3d(
+                B._data if hasattr(B, "_data") else B, block_size=block_size
+            )
+            # (E, N//block_size, K) -> (E, K, N//block_size)
+            B_scales = B_scales.transpose(-2, -1)
+        else:
+            B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)

         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
@@ -400,6 +403,7 @@ def backward(ctx, grad_out: torch.Tensor):
         grad_out_t_scales = grad_out_t_mx._scale_e8m0

         # Transpose A so we can scale along the M dimension, then un-transpose.
+        # A shape: (M, K)
         # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
         A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
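
Usage sketch for the new torchao::mxfp8_quantize_cuda_3d custom op, for reviewers who want to exercise it locally. This is a minimal sketch, not part of the diff: it assumes an SM100+ GPU with the mxfp8_cuda extension built from source (otherwise the op raises NotImplementedError), and the E/N/K sizes below are made up for illustration.

import torch

from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d

# Hypothetical sizes: E experts, each with an (N, K) weight matrix.
E, N, K = 16, 8192, 5120
block_size = 32

x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)

# Quantize along N; each block of `block_size` elements shares one e8m0 scale.
q_data, scales = mxfp8_quantize_cuda_3d(x, block_size=block_size, scaling_mode="floor")

assert q_data.shape == (E, N, K) and q_data.dtype == torch.float8_e4m3fn
assert scales.shape == (E, N // block_size, K) and scales.dtype == torch.float8_e8m0fnu

Because the op is registered through torch.library.custom_op with a register_fake implementation, it can be traced under torch.compile and FakeTensor without launching the kernel, which is the motivation for wrapping the raw mxfp8_cuda.quantize_3d call instead of calling it directly.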