benchmarks/utils.py (6 additions, 1 deletion)
@@ -25,8 +25,13 @@ def fwd_bwd(*args, **kwargs):
 
 def bench_fwd_microseconds(fn, *args, use_compile=False, fullgraph=True, **kwargs):
     fn_compiled = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
+
+    def inference_fn(*args, **kwargs):
Review comment (Contributor), on the `inference_fn` line above: why do we need this? forward in training is not run with torch.inference_mode()

+        with torch.inference_mode():
+            return fn_compiled(*args, **kwargs)
+
     return benchmark_cuda_function_in_microseconds(
-        fn_compiled,
+        inference_fn,
         *args,
         **kwargs,
     )
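Aside on the review question above: timing the forward under torch.inference_mode() disables autograd tracking entirely, so the measured latency can differ from the forward pass as it actually runs during training. The standalone sketch below (not part of the PR; the timing helper, model, and sizes are made up for illustration) shows how the two variants would be compared:

import torch

def time_cuda_microseconds(fn, *args, warmup=10, iters=100):
    # Simple CUDA-event timer (hypothetical helper); assumes a CUDA device is available.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # elapsed_time is in ms; convert to us per call

if torch.cuda.is_available():
    model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    def fwd_training_style(inp):
        # Autograd records the graph, as it would during a training forward pass.
        return model(inp)

    def fwd_inference_mode(inp):
        # No autograd bookkeeping at all.
        with torch.inference_mode():
            return model(inp)

    print("training-style fwd (us):", time_cuda_microseconds(fwd_training_style, x))
    print("inference_mode fwd (us):", time_cuda_microseconds(fwd_inference_mode, x))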
torchao/prototype/moe_training/scaled_grouped_mm.py (54 additions, 10 deletions)
@@ -58,15 +58,13 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
    if scaling_type == MoEScalingType.FP8_ROWWISE:
-        logger.debug("Using fp8 rowwise for _scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
             offs,
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.debug("Using mxfp8 for _scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,
@@ -358,13 +356,17 @@ def backward(ctx, grad_out: torch.Tensor):
 
         # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_data = to_mx(
+        B_scales_ref, B_data_ref = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
+
+        # Experiment with cuda kernel
+        B = B_t.transpose(-2, -1)
+        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
 
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,
@@ -376,21 +378,26 @@ def backward(ctx, grad_out: torch.Tensor):
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         grad_A = torch._scaled_grouped_mm(
             grad_out_data,
-            B_data.transpose(-2, -1),
+            B_data,
             grad_out_scales_blocked,
             B_scales_blocked,
             offs=offs,
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_data shape: (N, M)
+        # grad_out_t_data shape: (M, N)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_data = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            grad_out.transpose(-2, -1).contiguous(),
+        grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            grad_out,
+            block_size,
             elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            hp_dtype=grad_out.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
         )
+        grad_out_t_data = grad_out_t_mx.qdata
+        grad_out_t_scales = grad_out_t_mx._scale_e8m0
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A_t_data shape: (K, M)
@@ -412,7 +419,6 @@ def backward(ctx, grad_out: torch.Tensor):
         _, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
             scale_group_offsets
         )
-
         grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
             grad_out_t_scales,
             scale_group_offsets,
@@ -438,6 +444,40 @@
         return grad_A, grad_B_t, None, None, None
 
 
+def _to_mxfp8_dim1_3d(
+    B: torch.Tensor,
+    block_size: int = 32,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert a 3D tensor to MXFP8 format with (block_size, 1) scaling granularity.
+    """
+    E, N, K = B.shape
+    B_reshaped = B.reshape(E * N, K)
+    B_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+        B_reshaped,
+        block_size,
+        elem_dtype=torch.float8_e4m3fn,
+        hp_dtype=B_reshaped.dtype,
+        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+        scale_calculation_mode=ScaleCalculationMode.FLOOR,
+    )
+    B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
+    B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
+    B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+    B_scales = B_scales.reshape(
+        K, E, N // block_size
+    )  # (K, E*N//block_size) -> (K, E, N//block_size)
+    B_scales = B_scales.permute(
+        1, 0, 2
+    )  # (K, E, N//block_size) -> (E, K, N//block_size)
+    B_scales = B_scales.view(torch.float8_e8m0fnu)
+
+    # TODO: Update cutlass grouped gemm to accept NT/TN/NN/TT layouts so we can avoid this conversion to column major
+    B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
+    return B_scales, B_data
+
+
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_data: torch.Tensor,
     A_scale: torch.Tensor,
@@ -606,3 +646,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
     # Perform bf16 grouped GEMM using dequantized A and B.
     out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def round_up(x, y):
+    return ((x + y - 1) // y) * y
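For readers following the layout changes in _to_mxfp8_dim1_3d above: the sketch below walks through the same shape bookkeeping with dummy tensors, replacing the CUDA dim1 cast with zero-filled stand-ins of the shapes noted in the function's comments, so only the reshape/permute logic is exercised (illustration only, not code from this PR):

import torch

E, N, K, block_size = 4, 128, 256, 32

# Stand-ins for what the dim1 cast of the flattened (E*N, K) input would return:
# quantized data laid out as (K, E*N) and one scale per block_size-column chunk.
qdata_t = torch.zeros(K, E * N, dtype=torch.uint8)                  # stand-in for qdata
scales = torch.zeros(K, (E * N) // block_size, dtype=torch.uint8)   # stand-in for _scale_e8m0

# Data: (K, E*N) -> (E*N, K) -> (E, N, K)
B_data = qdata_t.t().reshape(E, N, K)

# Scales: (K, E*N//block_size) -> (K, E, N//block_size) -> (E, K, N//block_size)
B_scales = scales.reshape(K, E, N // block_size).permute(1, 0, 2)

assert B_data.shape == (E, N, K)
assert B_scales.shape == (E, K, N // block_size)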