@@ -17,7 +17,7 @@
 from benchmarks.utils import benchmark_cuda_function_in_microseconds
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )
benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py (3 additions & 3 deletions)
@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype import mxfp8_cuda
+from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _to_mxfp8_dim1_3d,
 )
@@ -110,9 +110,9 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
     )

     # bench 3d cuda kernel
-    data_cuda_3d, scales_cuda_3d = mxfp8_cuda.quantize_3d(input_tensor)
+    data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
     time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
-        mxfp8_cuda.quantize_3d,
+        mxfp8_quantize_cuda_3d,
         input_tensor,
     )

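For reference, a minimal standalone version of the benchmark above might look like the sketch below. The shape (E, N, K) = (8, 8192, 5120) is an arbitrary choice for illustration, and the sketch assumes a CUDA device with SM100+ and the mxfp8_cuda extension built from source, since mxfp8_quantize_cuda_3d otherwise raises NotImplementedError.

# Standalone benchmark sketch (hypothetical shapes; requires SM100+ and the
# torchao mxfp8_cuda extension built from source).
import torch

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d

E, N, K = 8, 8192, 5120  # arbitrary (experts, rows, cols); N must be divisible by 32
input_tensor = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

# Warm up once, then time the kernel.
data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
    mxfp8_quantize_cuda_3d,
    input_tensor,
)
print(f"mxfp8_quantize_cuda_3d: {time_cuda_3d_us:.2f} us")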
@@ -14,7 +14,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_M_groups,
     triton_mx_block_rearrange_2d_M_groups,
@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_per_group_3d,
     triton_mx_block_rearrange_per_group_3d,
 )
test/prototype/moe_training/test_kernels.py (1 addition & 1 deletion)
@@ -20,7 +20,7 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
 )
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_K_groups,
     compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_K_groups,
@@ -1,3 +1,4 @@
+import logging
 from typing import Tuple

 import torch
@@ -7,7 +8,10 @@
 from torch.library import triton_op, wrap_triton

 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import ceil_div
+from torchao.utils import (
+    ceil_div,
+    is_sm_at_least_100,
+)


 def torch_to_blocked_2d_M_groups(
@@ -645,3 +649,77 @@ def _dest_indices_for_block(
     # Flatten
     dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
     return dest_indices_flat
+
+
+mxfp8_cuda_extension_available = False
+if is_sm_at_least_100():
+    try:
+        # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
+        # currently our CI runners are not SM100+, so the user needs to build
+        # from source.
+        # TODO(#2932): improve this
+        from torchao.prototype import mxfp8_cuda
+
+        mxfp8_cuda_extension_available = True
+    except ImportError:
+        logging.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+
+if mxfp8_cuda_extension_available:
+    # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
+    # Currently we have to use an arbitrary string because custom ops don't support enum
+    # params.
+    @torch.library.custom_op("torchao::mxfp8_quantize_cuda_3d", mutates_args=())
+    def mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantizes a 3D tensor of shape (E,N,K) to MXFP8 format, scaling along N.
+
+        Args:
+            x (torch.Tensor): Input tensor to be quantized.
+            block_size (int, optional): Block size for quantization. Defaults to 32.
+            scaling_mode (str, optional): Scaling mode for quantization. Defaults to "floor".
+
+        Returns:
+            torch.Tensor: quantized tensor
+            torch.Tensor: scales tensor
+        """
+        assert x.ndim == 3, "Input tensor must be 3D"
+        assert x.dtype in (torch.float32, torch.bfloat16), (
+            "Input tensor must be float32 or bfloat16"
+        )
+        q_data, scales = mxfp8_cuda.quantize_3d(
+            x, scale_dim_n=block_size, scaling_mode=scaling_mode
+        )
+        return q_data, scales
+
+    @mxfp8_quantize_cuda_3d.register_fake
+    def _fake_mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert x.ndim == 3, "Input tensor must be 3D"
+        assert x.dtype in (torch.float32, torch.bfloat16), (
+            "Input tensor must be float32 or bfloat16"
+        )
+        E, N, K = x.shape
+        # Quantized tensor is in column major layouts
+        q_data = x.new_empty(x.shape, dtype=torch.float8_e4m3fn).as_strided(
+            x.shape, (N * K, 1, N)
+        )
+        scales = x.new_empty((E, N // block_size, K), dtype=torch.float8_e8m0fnu)
+        return q_data, scales
+
+else:
+
+    def mxfp8_quantize_cuda_3d(
+        x: torch.Tensor,
+        block_size: int = 32,
+        scaling_mode: str = "floor",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError(
+            "mxfp8_quantize_cuda_3d is not implemented on this device"
+        )
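As a quick sanity check of the contract the fake registration above encodes, a hedged sketch like the one below (again assuming an SM100+ GPU with the mxfp8_cuda extension available, so the real op rather than the fallback is registered) quantizes a small tensor and verifies the output dtypes, the scales shape, and the column-major data layout.

# Contract-check sketch; small shapes chosen arbitrarily for illustration.
import torch

from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d

E, N, K, block_size = 4, 128, 64, 32
x = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

q_data, scales = mxfp8_quantize_cuda_3d(x, block_size=block_size)

assert q_data.dtype == torch.float8_e4m3fn
assert scales.dtype == torch.float8_e8m0fnu
assert scales.shape == (E, N // block_size, K)
# Per the fake impl, data is column-major within each expert: strides (N * K, 1, N).
assert q_data.shape == (E, N, K) and q_data.stride() == (N * K, 1, N)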
torchao/prototype/moe_training/kernels/mxfp8_gemms.py (1 addition & 1 deletion)
@@ -2,7 +2,7 @@

 import torch

-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )
torchao/prototype/moe_training/scaled_grouped_mm.py (16 additions & 12 deletions)
@@ -16,9 +16,10 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_rowwise_3d_transpose_rhs,
 )
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     compute_blocked_scale_offsets_for_K_groups,
     compute_blocked_scale_offsets_for_M_groups,
+    mxfp8_quantize_cuda_3d,
     triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
@@ -354,18 +355,20 @@ def backward(ctx, grad_out: torch.Tensor):
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )

-        # B_data shape: (E, K, N)
-        # B_scale shape: (E, K, N//block_size)
-        B_scales_ref, B_data_ref = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            B_t.contiguous(),
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-        )
-
-        # Experiment with cuda kernel
+        # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
         # (E, K, N) -> (E, N, K)
         B = B_t.transpose(-2, -1)
-        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+        E, N, K = B.shape
+
+        # mxfp8_quantize_cuda_3d is only faster for E > 8
+        if E > 8:
+            B_data, B_scales = mxfp8_quantize_cuda_3d(
+                B._data if hasattr(B, "_data") else B, block_size=block_size
+            )
+            # (E, N//block_size, K) -> (E, K, N//block_size)
+            B_scales = B_scales.transpose(-2, -1)
+        else:
+            B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
+
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
@@ -400,6 +403,7 @@ def backward(ctx, grad_out: torch.Tensor):
         grad_out_t_scales = grad_out_t_mx._scale_e8m0

         # Transpose A so we can scale along the M dimension, then un-transpose.
+        # A shape: (M, K)
         # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
         A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
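To make the shape bookkeeping in the E > 8 branch above concrete, here is a hedged sketch with hypothetical shapes: the expert weight B_t of shape (E, K, N) is transposed so that quantization scales along N (the contraction dimension of the next grouped gemm), and the resulting scales are transposed from (E, N//block_size, K) to (E, K, N//block_size). It assumes an SM100+ GPU with the mxfp8_cuda extension built.

# Shape-flow sketch for the CUDA quantization path (illustrative shapes only).
import torch

from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d

block_size = 32
E, K, N = 16, 5120, 8192  # hypothetical expert weight shape, B_t: (E, K, N)
B_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")

# (E, K, N) -> (E, N, K): scale along N, the contraction dim of the next grouped gemm.
B = B_t.transpose(-2, -1)
B_data, B_scales = mxfp8_quantize_cuda_3d(B, block_size=block_size)

assert B_data.shape == (E, N, K)
assert B_scales.shape == (E, N // block_size, K)

# (E, N//block_size, K) -> (E, K, N//block_size), matching the layout used downstream.
B_scales = B_scales.transpose(-2, -1)
assert B_scales.shape == (E, K, N // block_size)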