
Commit 6403a25

[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
1 parent 213a554 commit 6403a25
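
In practice the new kernel is exposed as torchao.prototype.mxfp8_cuda.quantize_3d, which the test and benchmark below both exercise. A minimal usage sketch based on those call sites (shapes are illustrative; assumes the torchao prototype CUDA extension is built and a capability-10.0+ GPU):

    import torch
    from torchao.prototype import mxfp8_cuda

    # (E, N, K) expert activations in bf16, quantized colwise (along N)
    x = torch.randn(2, 8192, 5120, dtype=torch.bfloat16, device="cuda")
    y, s = mxfp8_cuda.quantize_3d(x, scale_dim_n=32, scaling_mode="floor")
    # y: float8_e4m3fn data in a column-major layout
    # s: float8_e8m0fnu scales, one per 32-element block along N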

File tree

6 files changed: +752, −5 lines
Lines changed: 185 additions & 0 deletions

@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# This benchmarking script is a modified version of the original script from:
# https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype import mxfp8_cuda
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _to_mxfp8_dim1_3d,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

device = torch.device("cuda")

# Needed since changing args to the function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int, ...]


@dataclass(frozen=True)
class ExperimentResult:
    # time
    to_mx_us: float
    cuda_2d_us: float
    cuda_3d_us: float
    # mem bw
    to_mx_gbps: float
    cuda_2d_gbps: float
    cuda_3d_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 shapes. Input activations are scaled along the K dim.
    input_shapes = [
        (1, 8192, 5120),
        (2, 8192, 5120),
        (4, 8192, 5120),
        (8, 8192, 5120),
        (16, 8192, 5120),
        (64, 8192, 5120),
    ]
    configs = []
    for shape in input_shapes:
        configs.append(
            ExperimentConfig(
                input_shape=shape,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    block_size = 32
    input_shape = config.input_shape
    input_tensor = torch.randn(
        *input_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    def using_to_mx(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Reference implementation
        s_d1_ref, y_d1_ref = to_mx(
            # Transpose (E,N,K) to (E,K,N) so N is the final dim,
            # since to_mx scales along that dim
            x.transpose(-2, -1).contiguous(),
            elem_dtype=torch.float8_e4m3fn,
            block_size=block_size,
        )

        # Transpose tensors and scales back so we have effectively
        # quantized input shape (E, N, K) along N
        y_d1_ref = y_d1_ref.transpose(-2, -1)
        s_d1_ref = s_d1_ref.transpose(-2, -1)
        return y_d1_ref, s_d1_ref

    # bench to_mx
    using_to_mx_c = torch.compile(using_to_mx)
    data_to_mx, scales_to_mx = using_to_mx_c(input_tensor)
    to_mx_time_us = benchmark_cuda_function_in_microseconds(
        using_to_mx_c,
        input_tensor,
    )

    # bench the 2d dim1 kernel, then transform to col major
    using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
    scales_cuda_2d, data_cuda_2d = using_cuda_2d_c(input_tensor)
    time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
        using_cuda_2d_c,
        input_tensor,
    )

    # bench the 3d cuda kernel
    data_cuda_3d, scales_cuda_3d = mxfp8_cuda.quantize_3d(input_tensor)
    time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
        mxfp8_cuda.quantize_3d,
        input_tensor,
    )

    # mem bw calculations
    bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
    bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = (
        data_cuda_3d.numel() * bytes_per_output_el
        + scales_cuda_3d.numel() * bytes_per_scale_el
    )

    to_mx_gbps = ((read_bytes + write_bytes) / 1e9) / (to_mx_time_us / 1e6)
    cuda_2d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_2d_us / 1e6)
    cuda_3d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_3d_us / 1e6)

    return ExperimentResult(
        # time
        to_mx_us=to_mx_time_us,
        cuda_2d_us=time_cuda_2d_us,
        cuda_3d_us=time_cuda_3d_us,
        # mem bw
        to_mx_gbps=to_mx_gbps,
        cuda_2d_gbps=cuda_2d_gbps,
        cuda_3d_gbps=cuda_3d_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "to_mx_us",
        "cuda_2d_us",
        "cuda_3d_us",
        "to_mx_gbps",
        "cuda_2d_gbps",
        "cuda_3d_gbps",
    ]
    rows = []
    for experiment in experiments:
        rows.append(
            [
                str(experiment.config.input_shape),
                experiment.result.to_mx_us,
                experiment.result.cuda_2d_us,
                experiment.result.cuda_3d_us,
                round(experiment.result.to_mx_gbps, 3),
                round(experiment.result.cuda_2d_gbps, 3),
                round(experiment.result.cuda_3d_gbps, 3),
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
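
To make the bandwidth arithmetic in run_experiment concrete, here is a worked example for the (8, 8192, 5120) shape; the timing below is a placeholder for illustration, not a measured result:

    E, N, K = 8, 8192, 5120
    numel = E * N * K                  # 335,544,320 elements
    read_bytes = numel * 2             # bf16 input: 671,088,640 bytes
    write_bytes = numel + numel // 32  # fp8 data + e8m0 scales: 346,030,080 bytes
    time_us = 500                      # hypothetical kernel time
    gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)
    # -> roughly 2034.2 GB/s at that hypothetical timing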

test/prototype/moe_training/test_kernels.py

Lines changed: 54 additions & 2 deletions
@@ -12,7 +12,6 @@
 if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
-
 from torchao.prototype.moe_training.kernels.float8_rowwise import (
     triton_fp8_rowwise_3d_transpose_rhs,
     triton_fp8_rowwise_3d_transpose_rhs_fused_reduction,
@@ -38,8 +37,11 @@
     torch_to_float8_per_group_colwise,
     torch_to_float8_per_group_rowwise,
 )
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
 from torchao.testing.utils import skip_if_rocm
+from torchao.utils import (
+    is_sm_at_least_100,
+)
 
 
 @skip_if_rocm("ROCm enablement in progress")
@@ -316,3 +318,53 @@ def test_triton_mx_block_rearrange_2d_K_groups(
         output_group_offsets,
     )
     assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.parametrize("E", (1, 2, 4, 8))
+@pytest.mark.parametrize("N", (32, 64, 8192))
+@pytest.mark.parametrize("K", (32, 64, 8192))
+@pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
+@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,))
+def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
+    from torchao.prototype import mxfp8_cuda
+
+    scaling_mode_str = (
+        "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
+    )
+    block_size = 32
+
+    # Use distinct incrementing values from 0 to E*N*K-1 to make debugging easier.
+    x = (
+        torch.arange(0, E * N * K, dtype=input_dtype, device="cuda")
+        .reshape(E, N, K)
+        .contiguous()
+    )
+
+    # Reference implementation
+    s_d1_ref, y_d1_ref = to_mx(
+        # Transpose so N is the final dim, since to_mx scales along that dim
+        x.transpose(-2, -1).contiguous(),
+        elem_dtype=torch.float8_e4m3fn,
+        block_size=block_size,
+    )
+
+    # Transpose tensors and scales back so we have effectively
+    # quantized input shape (E, N, K) along N
+    y_d1_ref = y_d1_ref.transpose(-2, -1)
+    s_d1_ref = s_d1_ref.transpose(-2, -1)
+
+    # CUDA implementation (should work with any stride pattern)
+    y_d1, s_d1 = mxfp8_cuda.quantize_3d(
+        x, scale_dim_n=block_size, scaling_mode=scaling_mode_str
+    )
+
+    # Check scales
+    torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
+
+    # Check quantized values
+    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
+    assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
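
Once the extension is built on a capability-10.0+ machine, the new test can be run in isolation with pytest's -k filter:

    pytest test/prototype/moe_training/test_kernels.py -k test_cuda_mx_dim1_3d_numerics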

torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu

Lines changed: 68 additions & 0 deletions
@@ -109,4 +109,72 @@ void mxfp8_quantize_cuda(const torch::Tensor &input,
                         stream);
 }
 
+void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
+                            torch::Tensor &output_colwise,
+                            torch::Tensor &scales_colwise,
+                            int64_t scale_dim_n,
+                            const std::string &fp8_format,
+                            const std::string &scaling_mode) {
+
+  // Get tensor properties for 3D tensor (E, N, K)
+  const int64_t E = input.size(0);
+  const int64_t N = input.size(1);
+  const int64_t K = input.size(2);
+
+  // Get data pointers
+  const void *input_ptr = input.data_ptr();
+  void *output_colwise_ptr = output_colwise.data_ptr();
+  e8m0_t *scales_colwise_ptr =
+      reinterpret_cast<e8m0_t *>(scales_colwise.data_ptr());
+
+  // Get CUDA stream
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Get strides of scales tensor
+  int64_t scales_colwise_stride_dim0 = scales_colwise.stride(0);
+  int64_t scales_colwise_stride_dim1 = scales_colwise.stride(1);
+  int64_t scales_colwise_stride_dim2 = scales_colwise.stride(2);
+
+  // Get input tensor strides for generic layout support
+  int64_t input_stride_dim0 = input.stride(0); // E dimension stride
+  int64_t input_stride_dim1 = input.stride(1); // N dimension stride
+  int64_t input_stride_dim2 = input.stride(2); // K dimension stride
+
+  // Get output tensor strides (should be column major)
+  int64_t output_stride_dim0 = output_colwise.stride(0); // E dimension stride
+  int64_t output_stride_dim1 = output_colwise.stride(1); // N dimension stride
+  int64_t output_stride_dim2 = output_colwise.stride(2); // K dimension stride
+
+#if defined(DEBUG)
+  printf("mxfp8_quantize_3d_cuda:\n");
+  printf("Quantizing 3D input tensor of size %ld x %ld x %ld\n", E, N, K);
+  printf("scaling_mode: %s\n", scaling_mode.c_str());
+  printf("Scale dim n: %ld\n", scale_dim_n);
+  printf("Output scale shape: %ld x %ld x %ld\n",
+         scales_colwise.sizes()[0], scales_colwise.sizes()[1],
+         scales_colwise.sizes()[2]);
+  printf("scales_colwise_stride_dim0 = %ld\n", scales_colwise_stride_dim0);
+  printf("scales_colwise_stride_dim1 = %ld\n", scales_colwise_stride_dim1);
+  printf("input_stride_dim0 = %ld\n", input_stride_dim0);
+  printf("input_stride_dim1 = %ld\n", input_stride_dim1);
+  printf("input_stride_dim2 = %ld\n", input_stride_dim2);
+  printf("output_stride_dim0 = %ld\n", output_stride_dim0);
+  printf("output_stride_dim1 = %ld\n", output_stride_dim1);
+  printf("output_stride_dim2 = %ld\n", output_stride_dim2);
+#endif
+
+  // Call the 3D quantization kernel
+  MXFP8Quantizer::quantize_3d(input_ptr,
+                              output_colwise_ptr,
+                              scales_colwise_ptr,
+                              E, N, K,
+                              input_stride_dim0, input_stride_dim1, input_stride_dim2,
+                              output_stride_dim0, output_stride_dim1, output_stride_dim2,
+                              scales_colwise_stride_dim0, scales_colwise_stride_dim1,
+                              scales_colwise_stride_dim2,
+                              get_input_dtype(input), get_output_dtype(fp8_format),
+                              scale_dim_n,
+                              get_scaling_mode(scaling_mode),
+                              stream);
+}
+
 } // namespace mxfp8
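
The "should be column major" note on the output strides corresponds, on the Python side, to an (E, N, K) tensor whose last two dims are laid out column major. A small illustrative sketch (shape matches the benchmark; not part of this commit):

    import torch

    # Transposing a contiguous (E, K, N) tensor yields an (E, N, K) view
    # with strides (N * K, 1, N): column major in the last two dims.
    t = torch.empty(2, 5120, 8192, device="cuda").transpose(-2, -1)
    print(t.shape)     # torch.Size([2, 8192, 5120])
    print(t.stride())  # (41943040, 1, 8192)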
