|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py |
| 7 | + |
| 8 | +from dataclasses import dataclass |
| 9 | +from typing import List |
| 10 | + |
| 11 | +import torch |
| 12 | +from tabulate import tabulate |
| 13 | +from tqdm import tqdm |
| 14 | + |
| 15 | +from benchmarks.utils import benchmark_cuda_function_in_microseconds |
| 16 | +from torchao.prototype import mxfp8_cuda |
| 17 | +from torchao.prototype.moe_training.scaled_grouped_mm import ( |
| 18 | + _to_mxfp8_dim1_3d, |
| 19 | +) |
| 20 | +from torchao.prototype.mx_formats.mx_tensor import to_mx |
| 21 | + |
| 22 | +device = torch.device("cuda") |
| 23 | + |
| 24 | +# Needed since changing args to function causes recompiles |
| 25 | +torch._dynamo.config.cache_size_limit = 1000 |
| 26 | + |
| 27 | + |
| 28 | +@dataclass(frozen=True) |
| 29 | +class ExperimentConfig: |
| 30 | + input_shape: tuple[int] |
| 31 | + |
| 32 | + |
| 33 | +@dataclass(frozen=True) |
| 34 | +class ExperimentResult: |
| 35 | + # time |
| 36 | + to_mx_us: float |
| 37 | + cuda_2d_us: float |
| 38 | + cuda_3d_us: float |
| 39 | + # mem bw |
| 40 | + to_mx_gbps: float |
| 41 | + cuda_2d_gbps: float |
| 42 | + cuda_3d_gbps: float |
| 43 | + |
| 44 | + |
| 45 | +@dataclass(frozen=True) |
| 46 | +class Experiment: |
| 47 | + config: ExperimentConfig |
| 48 | + result: ExperimentResult |
| 49 | + |
| 50 | + |
| 51 | +def get_configs() -> List[ExperimentConfig]: |
| 52 | + # Llama4 shapes. Input activations are scaled along K dim. |
| 53 | + input_shapes = [ |
| 54 | + (1, 8192, 5120), |
| 55 | + (2, 8192, 5120), |
| 56 | + (4, 8192, 5120), |
| 57 | + (8, 8192, 5120), |
| 58 | + (16, 8192, 5120), |
| 59 | + (64, 8192, 5120), |
| 60 | + ] |
| 61 | + configs = [] |
| 62 | + for shape in input_shapes: |
| 63 | + configs.append( |
| 64 | + ExperimentConfig( |
| 65 | + input_shape=shape, |
| 66 | + ) |
| 67 | + ) |
| 68 | + return configs |
| 69 | + |
| 70 | + |
| 71 | +def run_experiment(config: ExperimentConfig) -> ExperimentResult: |
| 72 | + block_size = 32 |
| 73 | + input_shape = config.input_shape |
| 74 | + input_tensor = torch.randn( |
| 75 | + *input_shape, |
| 76 | + dtype=torch.bfloat16, |
| 77 | + device=device, |
| 78 | + ) |
| 79 | + |
| 80 | + def using_to_mx(x: torch.Tensor) -> torch.Tensor: |
| 81 | + # Reference implementation |
| 82 | + s_d1_ref, y_d1_ref = to_mx( |
| 83 | + # Transpose (E,N,K) to (E,K,N) so N is final dim, |
| 84 | + # since to_mx scales along that dim |
| 85 | + x.transpose(-2, -1).contiguous(), |
| 86 | + elem_dtype=torch.float8_e4m3fn, |
| 87 | + block_size=block_size, |
| 88 | + ) |
| 89 | + |
| 90 | + # Transpose tensors and scales back so we have effectively |
| 91 | + # quantized input shape (E, N, K) along N |
| 92 | + y_d1_ref = y_d1_ref.transpose(-2, -1) |
| 93 | + s_d1_ref = s_d1_ref.transpose(-2, -1) |
| 94 | + return y_d1_ref, s_d1_ref |
| 95 | + |
| 96 | + # bench to_mx |
| 97 | + using_to_mx_c = torch.compile(using_to_mx) |
| 98 | + scales_to_mx, data_to_mx = using_to_mx_c(input_tensor) |
| 99 | + to_mx_time_us = benchmark_cuda_function_in_microseconds( |
| 100 | + using_to_mx_c, |
| 101 | + input_tensor, |
| 102 | + ) |
| 103 | + |
| 104 | + # bench 2d dim1 kernel then transforming to col major |
| 105 | + using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d) |
| 106 | + scales_cuda_2d, data_cuda_2d = using_cuda_2d_c(input_tensor) |
| 107 | + time_cuda_2d_us = benchmark_cuda_function_in_microseconds( |
| 108 | + using_cuda_2d_c, |
| 109 | + input_tensor, |
| 110 | + ) |
| 111 | + |
| 112 | + # bench 3d cuda kernel |
| 113 | + data_cuda_3d, scales_cuda_3d = mxfp8_cuda.quantize_3d(input_tensor) |
| 114 | + time_cuda_3d_us = benchmark_cuda_function_in_microseconds( |
| 115 | + mxfp8_cuda.quantize_3d, |
| 116 | + input_tensor, |
| 117 | + ) |
| 118 | + |
| 119 | + # mem bw calculations |
| 120 | + bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 |
| 121 | + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 |
| 122 | + bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 |
| 123 | + |
| 124 | + read_bytes = input_tensor.numel() * bytes_per_input_el |
| 125 | + write_bytes = ( |
| 126 | + data_cuda_3d.numel() * bytes_per_output_el |
| 127 | + + scales_cuda_3d.numel() * bytes_per_scale_el |
| 128 | + ) |
| 129 | + |
| 130 | + to_mx_gbps = ((read_bytes + write_bytes) / 1e9) / (to_mx_time_us / 1e6) |
| 131 | + cuda_2d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_2d_us / 1e6) |
| 132 | + cuda_3d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_3d_us / 1e6) |
| 133 | + |
| 134 | + return ExperimentResult( |
| 135 | + # time |
| 136 | + to_mx_us=to_mx_time_us, |
| 137 | + cuda_2d_us=time_cuda_2d_us, |
| 138 | + cuda_3d_us=time_cuda_3d_us, |
| 139 | + # mem bw |
| 140 | + to_mx_gbps=to_mx_gbps, |
| 141 | + cuda_2d_gbps=cuda_2d_gbps, |
| 142 | + cuda_3d_gbps=cuda_3d_gbps, |
| 143 | + ) |
| 144 | + |
| 145 | + |
| 146 | +def print_results(experiments: List[Experiment]): |
| 147 | + headers = [ |
| 148 | + "input_shape", |
| 149 | + "to_mx_us", |
| 150 | + "cuda_2d_us", |
| 151 | + "cuda_3d_us", |
| 152 | + "to_mx_gbps", |
| 153 | + "cuda_2d_gbps", |
| 154 | + "cuda_3d_gbps", |
| 155 | + ] |
| 156 | + rows = [] |
| 157 | + for experiment in experiments: |
| 158 | + rows.append( |
| 159 | + [ |
| 160 | + str(experiment.config.input_shape), |
| 161 | + experiment.result.to_mx_us, |
| 162 | + experiment.result.cuda_2d_us, |
| 163 | + experiment.result.cuda_3d_us, |
| 164 | + round(experiment.result.to_mx_gbps, 3), |
| 165 | + round(experiment.result.cuda_2d_gbps, 3), |
| 166 | + round(experiment.result.cuda_3d_gbps, 3), |
| 167 | + ] |
| 168 | + ) |
| 169 | + print(tabulate(rows, headers=headers)) |
| 170 | + |
| 171 | + |
| 172 | +def main(): |
| 173 | + torch.random.manual_seed(123) |
| 174 | + configs = get_configs() |
| 175 | + results = [] |
| 176 | + for config in tqdm(configs): |
| 177 | + result = run_experiment(config) |
| 178 | + results.append(Experiment(config=config, result=result)) |
| 179 | + |
| 180 | + # Use Tabulate to print results |
| 181 | + print_results(results) |
| 182 | + |
| 183 | + |
| 184 | +if __name__ == "__main__": |
| 185 | + main() |
0 commit comments