Skip to content

Commit 2b1b340

Browse files
[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise
1 parent 7a948c8 commit 2b1b340

File tree

5 files changed

+583
-25
lines changed

5 files changed

+583
-25
lines changed

test/prototype/moe_training/test_kernels.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9):
1313
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
1414

15-
1615
from torchao.prototype.moe_training.kernels.float8_rowwise import (
1716
triton_fp8_rowwise_3d_transpose_rhs,
1817
triton_fp8_rowwise_3d_transpose_rhs_fused_reduction,
@@ -38,8 +37,11 @@
3837
torch_to_float8_per_group_colwise,
3938
torch_to_float8_per_group_rowwise,
4039
)
41-
from torchao.prototype.mx_formats.mx_tensor import to_mx
40+
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
4241
from torchao.testing.utils import skip_if_rocm
42+
from torchao.utils import (
43+
is_sm_at_least_100,
44+
)
4345

4446

4547
@skip_if_rocm("ROCm enablement in progress")
@@ -313,3 +315,57 @@ def test_mxfp8_per_group_blocked_scales_2d2d(
313315
output_group_offsets,
314316
)
315317
assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"
318+
319+
320+
@pytest.mark.skipif(
    not is_sm_at_least_100(),
    reason="MXFP8 requires CUDA capability 10.0 or greater",
)
@pytest.mark.parametrize(
    "E",
    (
        1,
        2,
    ),
)
@pytest.mark.parametrize("N", (32, 64))
@pytest.mark.parametrize("K", (32, 64))
@pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,))
def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
    """Compare ``mxfp8_cuda.quantize_3d`` against the ``to_mx`` reference.

    The input is an (E, N, K) tensor. The reference quantizes the transposed,
    contiguous (E, K, N) tensor along its last dim and is transposed back, so
    both the values and the strides of the quantized data can be compared
    exactly with the CUDA kernel output.
    """
    from torchao.prototype import mxfp8_cuda

    scaling_mode_str = (
        "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
    )
    block_size = 32

    # Use incrementing values from 0 to E*N*K-1 (rounded to input_dtype) to
    # make debugging easier.
    x = (
        torch.arange(0, E * N * K, dtype=input_dtype, device="cuda")
        .reshape(E, N, K)
        .contiguous()
    )

    # Reference implementation. The scaling mode is forwarded so that the
    # reference and the CUDA kernel always use the same mode, even if more
    # modes are added to the parametrization above.
    s_d1_ref, y_d1_ref = to_mx(
        x.transpose(-2, -1).contiguous(),
        elem_dtype=torch.float8_e4m3fn,
        block_size=block_size,
        scaling_mode=scaling_mode,
    )
    # Quantized data: (E, K, N) -> (E, N, K)
    y_d1_ref = y_d1_ref.transpose(-2, -1)

    # CUDA implementation (should work with any stride pattern)
    y_d1, s_d1 = mxfp8_cuda.quantize_3d(
        x, scale_dim_n=block_size, scaling_mode=scaling_mode_str
    )
    # Scales: (E, N//block_size, K) -> (E, K, N//block_size) to match the reference
    s_d1 = s_d1.transpose(-2, -1)

    # Check scales
    torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)

    # Check quantized values
    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
    assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"

torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,72 @@ void mxfp8_quantize_cuda(const torch::Tensor &input,
109109
stream);
110110
}
111111

112+
// Host-side launcher for colwise MXFP8 quantization of a 3D tensor.
//
// Reads the sizes, strides and data pointers of `input`, `output_colwise`
// and `scales_colwise`, picks up the current CUDA stream, and forwards
// everything to MXFP8Quantizer::quantize_3d. No shape, dtype or device
// validation is done here; callers are expected to have validated the
// tensors and allocated both outputs beforehand.
//
// Args:
//   input:          3D tensor, indexed as (E, N, K); any strides.
//   output_colwise: destination for the quantized values (written in place).
//   scales_colwise: destination for the e8m0 scales (written in place).
//   scale_dim_n:    scaling block size, forwarded to the quantizer.
//   fp8_format:     output format name, mapped via get_output_dtype.
//   scaling_mode:   scaling mode name, mapped via get_scaling_mode.
void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
                            torch::Tensor &output_colwise,
                            torch::Tensor &scales_colwise,
                            int64_t scale_dim_n,
                            const std::string &fp8_format,
                            const std::string &scaling_mode) {

  // Get tensor properties for 3D tensor (E, N, K)
  const int64_t E = input.size(0);
  const int64_t N = input.size(1);
  const int64_t K = input.size(2);

  // Get data pointers
  const void *input_ptr = input.data_ptr();
  void *output_colwise_ptr = output_colwise.data_ptr();
  e8m0_t *scales_colwise_ptr =
      reinterpret_cast<e8m0_t *>(scales_colwise.data_ptr());

  // Get CUDA stream
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Get strides of scales tensor
  const int64_t scales_colwise_stride_dim0 = scales_colwise.stride(0);
  const int64_t scales_colwise_stride_dim1 = scales_colwise.stride(1);
  const int64_t scales_colwise_stride_dim2 = scales_colwise.stride(2);

  // Get input tensor strides for generic layout support
  const int64_t input_stride_dim0 = input.stride(0); // E dimension stride
  const int64_t input_stride_dim1 = input.stride(1); // N dimension stride
  const int64_t input_stride_dim2 = input.stride(2); // K dimension stride

  // Get output tensor strides (should be col major)
  const int64_t output_stride_dim0 = output_colwise.stride(0); // E dimension stride
  const int64_t output_stride_dim1 = output_colwise.stride(1); // N dimension stride
  const int64_t output_stride_dim2 = output_colwise.stride(2); // K dimension stride

#if defined(DEBUG)
  // int64_t is not `long` on every platform, so cast explicitly and print
  // with %lld to keep the format specifiers well-defined everywhere.
  printf("mxfp8_quantize_3d_cuda:\n");
  printf("Quantizing 3D input tensor of size %lld x %lld x %lld\n",
         static_cast<long long>(E), static_cast<long long>(N),
         static_cast<long long>(K));
  printf("scaling_mode: %s\n", scaling_mode.c_str());
  printf("Scale dim n: %lld\n", static_cast<long long>(scale_dim_n));
  printf("Output scale shape: %lld x %lld x %lld\n",
         static_cast<long long>(scales_colwise.size(0)),
         static_cast<long long>(scales_colwise.size(1)),
         static_cast<long long>(scales_colwise.size(2)));
  printf("scales_colwise_stride_dim0 = %lld\n",
         static_cast<long long>(scales_colwise_stride_dim0));
  printf("scales_colwise_stride_dim1 = %lld\n",
         static_cast<long long>(scales_colwise_stride_dim1));
  printf("scales_colwise_stride_dim2 = %lld\n",
         static_cast<long long>(scales_colwise_stride_dim2));
  printf("input_stride_dim0 = %lld\n", static_cast<long long>(input_stride_dim0));
  printf("input_stride_dim1 = %lld\n", static_cast<long long>(input_stride_dim1));
  printf("input_stride_dim2 = %lld\n", static_cast<long long>(input_stride_dim2));
  printf("output_stride_dim0 = %lld\n", static_cast<long long>(output_stride_dim0));
  printf("output_stride_dim1 = %lld\n", static_cast<long long>(output_stride_dim1));
  printf("output_stride_dim2 = %lld\n", static_cast<long long>(output_stride_dim2));
#endif

  // Call the 3D quantization kernel
  MXFP8Quantizer::quantize_3d(input_ptr,
                              output_colwise_ptr,
                              scales_colwise_ptr,
                              E, N, K,
                              input_stride_dim0, input_stride_dim1, input_stride_dim2,
                              output_stride_dim0, output_stride_dim1, output_stride_dim2,
                              scales_colwise_stride_dim0, scales_colwise_stride_dim1, scales_colwise_stride_dim2,
                              get_input_dtype(input), get_output_dtype(fp8_format),
                              scale_dim_n,
                              get_scaling_mode(scaling_mode),
                              stream);
}
179+
112180
} // namespace mxfp8

torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ void mxfp8_quantize_cuda(const torch::Tensor &input,
1818
const std::string &fp8_format,
1919
const std::string &scaling_mode);
2020

21+
void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
22+
torch::Tensor &output_colwise,
23+
torch::Tensor &scales_colwise,
24+
int64_t scale_dim_n,
25+
const std::string &fp8_format,
26+
const std::string &scaling_mode);
27+
2128
// Helper for tensor validation
2229
void check_cuda_tensor(const torch::Tensor &t, const char *name) {
2330
TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
@@ -115,6 +122,60 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
115122
scales_colwise);
116123
}
117124

125+
// 3D tensor quantization function
126+
std::tuple<torch::Tensor, torch::Tensor>
127+
mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
128+
const std::string &fp8_format,
129+
const std::string &scaling_mode) {
130+
131+
// Validate inputs
132+
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
133+
// Note: We don't check contiguous for 3D as it may have column major strides
134+
TORCH_CHECK(input.dim() == 3, "input must be 3D");
135+
TORCH_CHECK(input.scalar_type() == torch::kFloat32 ||
136+
input.scalar_type() == torch::kFloat16 ||
137+
input.scalar_type() == torch::kBFloat16,
138+
"Input must be float32, float16, or bfloat16");
139+
TORCH_CHECK(scale_dim_n == 32, "scale_dim_n must be 32 for now");
140+
141+
validate_fp8_format(fp8_format);
142+
143+
const int64_t E = input.size(0);
144+
const int64_t N = input.size(1);
145+
const int64_t K = input.size(2);
146+
147+
// Check dimensions are valid for 3D kernel
148+
TORCH_CHECK((N >= 32) && (N % 32 == 0), "N must be a multiple of 32");
149+
TORCH_CHECK((K >= 32) && (K % 32 == 0), "K must be a multiple of 32");
150+
151+
// The kernel should work with any stride pattern - no layout requirements
152+
153+
c10::cuda::CUDAGuard device_guard(input.device());
154+
155+
// Create tensor options
156+
const auto options_fp8 = torch::TensorOptions()
157+
.dtype(torch::kFloat8_e4m3fn)
158+
.device(input.device());
159+
160+
const auto options_scale = torch::TensorOptions()
161+
.dtype(torch::kFloat8_e8m0fnu)
162+
.device(input.device());
163+
164+
// Create output tensor with column major layout (required for downstream ops)
165+
torch::Tensor output_colwise = torch::empty_strided(
166+
{E, N, K}, {N * K, 1, N}, options_fp8);
167+
168+
// Create scales tensor with shape (E, num_n_blocks, K)
169+
const int64_t num_n_blocks = (N + scale_dim_n - 1) / scale_dim_n;
170+
torch::Tensor scales_colwise = torch::empty({E, num_n_blocks, K}, options_scale);
171+
172+
// Call CUDA kernel
173+
mxfp8_quantize_3d_cuda(input, output_colwise, scales_colwise,
174+
scale_dim_n, fp8_format, scaling_mode);
175+
176+
return std::make_tuple(output_colwise, scales_colwise);
177+
}
178+
118179
} // namespace mxfp8
119180

120181
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -125,4 +186,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
125186
py::arg("scale_dim_x") = 32, py::arg("scale_dim_y") = 32,
126187
py::arg("fp8_format") = "e4m3",
127188
py::arg("scaling_mode") = "floor");
189+
190+
m.def("quantize_3d", &mxfp8::mxfp8_quantize_3d, "MXFP8 3D quantization",
191+
py::arg("input"), py::arg("scale_dim_n") = 32,
192+
py::arg("fp8_format") = "e4m3",
193+
py::arg("scaling_mode") = "floor");
128194
}

0 commit comments

Comments
 (0)