@@ -17,7 +17,7 @@
 from benchmarks.utils import benchmark_cuda_function_in_microseconds
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+from torchao.prototype.moe_training.kernels.mxfp8 import (
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )
96 changes: 59 additions & 37 deletions benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
@@ -28,20 +28,22 @@
 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 1000
 
+# Workaround for https://github.com/pytorch/ao/pull/2990#issuecomment-3285762681
+torch._dynamo.config.automatic_dynamic_shapes = False
 
 
 @dataclass(frozen=True)
 class ExperimentConfig:
     high_precision_dtype: torch.dtype
-    A_shape: tuple[int]
-    B_shape: tuple[int]
+    MNKG: tuple[int]
     recipe: MoEScalingType
 
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_e2e_us: float
-    scaled_e2e_us: float
-    scaled_e2e_speedup: float
+    bf16_fwd_bwd_us: float
+    scaled_fwd_bwd_us: float
+    scaled_fwd_bwd_speedup: float
     bf16_fwd_us: float
     scaled_fwd_us: float
     scaled_fwd_speedup: float
@@ -54,22 +56,46 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    # Llama4 shapes
-    A_shapes = [(16640, 5120)]
-    B_shapes = [(1, 8192, 5120), (4, 8192, 5120), (16, 8192, 5120), (64, 8192, 5120)]
+    MNKG_list = [
+        # Llama4 16e with various experts per device (i.e., different EP degrees)
+        (16640, 8192, 5120, 1),
+        (16640, 8192, 5120, 2),
+        (16640, 8192, 5120, 4),
+        (16640, 8192, 5120, 8),
+        # (16640, 5120, 8192, 1),
+        # (16640, 5120, 8192, 2),
+        # (16640, 5120, 8192, 4),
+        # (16640, 5120, 8192, 8),
+        # # DSV3 236B with various experts per device (i.e., different EP degrees)
+        # (16640, 5120, 1536, 1),
+        # (16640, 5120, 1536, 2),
+        # (16640, 5120, 1536, 4),
+        # (16640, 5120, 1536, 8),
+        # (16640, 1536, 5120, 1),
+        # (16640, 1536, 5120, 2),
+        # (16640, 1536, 5120, 4),
+        # (16640, 1536, 5120, 8),
+        # # DSV3 671B with various experts per device (i.e., different EP degrees)
+        # (16640, 7168, 2048, 1),
+        # (16640, 7168, 2048, 2),
+        # (16640, 7168, 2048, 4),
+        # (16640, 7168, 2048, 8),
+        # (16640, 2048, 7168, 1),
+        # (16640, 2048, 7168, 2),
+        # (16640, 2048, 7168, 4),
+        # (16640, 2048, 7168, 8),
+    ]
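    # Note: each tuple is (total_M, N, K, G), where A has shape (total_M, K) and
    # B has shape (G, N, K), so G is the number of expert groups per device.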
     recipes = [MoEScalingType.FP8_ROWWISE, MoEScalingType.MXFP8]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
-    for A_shape, B_shape, recipe, high_precision_dtype in itertools.product(
-        A_shapes,
-        B_shapes,
+    for MNKG, recipe, high_precision_dtype in itertools.product(
+        MNKG_list,
         recipes,
         high_precision_dtypes,
     ):
         configs.append(
             ExperimentConfig(
-                A_shape=A_shape,
-                B_shape=B_shape,
+                MNKG=MNKG,
                 recipe=recipe,
                 high_precision_dtype=high_precision_dtype,
             )
Expand All @@ -80,15 +106,17 @@ def get_configs() -> List[ExperimentConfig]:
def run_experiment(
config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
total_M, N, K, G = config.MNKG

# define test inputs
A = torch.randn(
*config.A_shape,
(total_M, K),
dtype=config.high_precision_dtype,
device=device,
requires_grad=True,
)
B_t = torch.randn(
*config.B_shape,
(G, N, K),
dtype=config.high_precision_dtype,
device=device,
requires_grad=True,
@@ -99,17 +127,15 @@ def run_experiment(
     # that occurs in the backward pass of the differentiable scaled grouped mm.
     # - the transposed tensor in col-major format with groups along the row dimension,
     #   which represents the right operand.
-    n_groups = config.B_shape[0]
-    Mg = A.shape[0]
     token_group_alignment_size = 32 if config.recipe == MoEScalingType.MXFP8 else 16
-    offs = generate_jagged_offs(n_groups, Mg, multiple_of=token_group_alignment_size)
+    offs = generate_jagged_offs(G, total_M, multiple_of=token_group_alignment_size)
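    # Illustrative note: offs holds the cumulative end index of each token group, so with
    # G=4 and total_M=16640 it might look like [4096, 8352, 12512, 16640], where every
    # entry is a multiple of token_group_alignment_size.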

     labels = torch.ones(
         (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
     )
 
-    # E2E bf16 benchmark + profiling
-    bf16_e2e_us = bench_fwd_bwd_microseconds(
+    # fwd_bwd bf16 benchmark + profiling
+    bf16_fwd_bwd_us = bench_fwd_bwd_microseconds(
         torch._grouped_mm,
         A,
         B_t,
@@ -130,8 +156,8 @@ def run_experiment(
         profile_name="bf16_profile",
     )
 
-    # E2E scaled benchmark + profiling
-    scaled_e2e_us = bench_fwd_bwd_microseconds(
+    # fwd_bwd scaled benchmark + profiling
+    scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
         _scaled_grouped_mm,
         A,
         B_t,
@@ -174,9 +200,9 @@ def run_experiment(
     )
 
     return ExperimentResult(
-        bf16_e2e_us=round(bf16_e2e_us, 3),
-        scaled_e2e_us=round(scaled_e2e_us, 3),
-        scaled_e2e_speedup=round(bf16_e2e_us / scaled_e2e_us, 3),
+        bf16_fwd_bwd_us=round(bf16_fwd_bwd_us, 3),
+        scaled_fwd_bwd_us=round(scaled_fwd_bwd_us, 3),
+        scaled_fwd_bwd_speedup=round(bf16_fwd_bwd_us / scaled_fwd_bwd_us, 3),
         bf16_fwd_us=round(bf16_fwd_us, 3),
         scaled_fwd_us=round(scaled_fwd_us, 3),
         scaled_fwd_speedup=round(bf16_fwd_us / scaled_fwd_us, 3),
@@ -185,28 +211,24 @@
 
 def print_results(experiments: List[Experiment]):
     headers = [
-        "A_shape",
-        "B_shape",
+        "M,N,K,G",
         "recipe",
-        "bf16_e2e_us",
-        "scaled_e2e_us",
-        "scaled_e2e_speedup",
+        "bf16_fwd_bwd_us",
+        "scaled_fwd_bwd_us",
+        "scaled_fwd_bwd_speedup",
         "bf16_fwd_us",
         "scaled_fwd_us",
         "scaled_fwd_speedup",
     ]
     rows = []
     for experiment in experiments:
-        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
-        B_shape = f"({experiment.config.B_shape[0]}, {experiment.config.B_shape[1]}, {experiment.config.B_shape[2]})"
         rows.append(
             [
-                A_shape,
-                B_shape,
+                str(experiment.config.MNKG),
                 experiment.config.recipe,
-                experiment.result.bf16_e2e_us,
-                experiment.result.scaled_e2e_us,
-                f"{experiment.result.scaled_e2e_speedup}x",
+                experiment.result.bf16_fwd_bwd_us,
+                experiment.result.scaled_fwd_bwd_us,
+                f"{experiment.result.scaled_fwd_bwd_speedup}x",
                 experiment.result.bf16_fwd_us,
                 experiment.result.scaled_fwd_us,
                 f"{experiment.result.scaled_fwd_speedup}x",
185 changes: 185 additions & 0 deletions benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# This benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _to_mxfp8_dim1_3d,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int]


@dataclass(frozen=True)
class ExperimentResult:
    # time
    to_mx_us: float
    cuda_2d_us: float
    cuda_3d_us: float
    # mem bw
    to_mx_gbps: float
    cuda_2d_gbps: float
    cuda_3d_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 shapes. Input activations are scaled along the K dim.
    input_shapes = [
        (1, 8192, 5120),
        (2, 8192, 5120),
        (4, 8192, 5120),
        (8, 8192, 5120),
        (16, 8192, 5120),
        (64, 8192, 5120),
    ]
    configs = []
    for shape in input_shapes:
        configs.append(
            ExperimentConfig(
                input_shape=shape,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    block_size = 32
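    # Per the MX spec, every block of 32 contiguous elements along the scaled
    # dim shares a single e8m0 scale.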
    input_shape = config.input_shape
    input_tensor = torch.randn(
        *input_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    def using_to_mx(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Reference implementation
        s_d1_ref, y_d1_ref = to_mx(
            # Transpose (E, N, K) to (E, K, N) so N is the final dim,
            # since to_mx scales along that dim
            x.transpose(-2, -1).contiguous(),
            elem_dtype=torch.float8_e4m3fn,
            block_size=block_size,
        )

        # Transpose data and scales back so we have effectively
        # quantized input shape (E, N, K) along N
        y_d1_ref = y_d1_ref.transpose(-2, -1)
        s_d1_ref = s_d1_ref.transpose(-2, -1)
        return y_d1_ref, s_d1_ref
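    # Illustrative shape walk-through for using_to_mx with x of shape (E, N, K):
    # to_mx sees (E, K, N) and returns scales (E, K, N // 32) plus fp8 data
    # (E, K, N); transposing back gives data (E, N, K) and scales (E, N // 32, K).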

    # bench to_mx
    using_to_mx_c = torch.compile(using_to_mx)
    # using_to_mx returns (data, scales), so unpack in that order
    data_to_mx, scales_to_mx = using_to_mx_c(input_tensor)
    to_mx_time_us = benchmark_cuda_function_in_microseconds(
        using_to_mx_c,
        input_tensor,
    )

    # bench 2d dim1 kernel then transforming to col major
    using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
    scales_cuda_2d, data_cuda_2d = using_cuda_2d_c(input_tensor)
    time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
        using_cuda_2d_c,
        input_tensor,
    )

    # bench 3d cuda kernel
    data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
    time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
        mxfp8_quantize_cuda_3d,
        input_tensor,
    )

    # mem bw calculations
    bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
    bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = (
        data_cuda_3d.numel() * bytes_per_output_el
        + scales_cuda_3d.numel() * bytes_per_scale_el
    )

    to_mx_gbps = ((read_bytes + write_bytes) / 1e9) / (to_mx_time_us / 1e6)
    cuda_2d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_2d_us / 1e6)
    cuda_3d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_3d_us / 1e6)
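    # Worked example, assuming one scale per 32 elements and no padding:
    # for input_shape (16, 8192, 5120), read_bytes = 16*8192*5120*2 ~ 1.342 GB,
    # write_bytes ~ 0.671 GB (fp8 data) + 0.021 GB (e8m0 scales) ~ 0.692 GB,
    # so a kernel taking 1000us would hit (1.342 + 0.692) / 1e-3 ~ 2034 GB/s.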

    return ExperimentResult(
        # time
        to_mx_us=to_mx_time_us,
        cuda_2d_us=time_cuda_2d_us,
        cuda_3d_us=time_cuda_3d_us,
        # mem bw
        to_mx_gbps=to_mx_gbps,
        cuda_2d_gbps=cuda_2d_gbps,
        cuda_3d_gbps=cuda_3d_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "to_mx_us",
        "cuda_2d_us",
        "cuda_3d_us",
        "to_mx_gbps",
        "cuda_2d_gbps",
        "cuda_3d_gbps",
    ]
    rows = []
    for experiment in experiments:
        rows.append(
            [
                str(experiment.config.input_shape),
                experiment.result.to_mx_us,
                experiment.result.cuda_2d_us,
                experiment.result.cuda_3d_us,
                round(experiment.result.to_mx_gbps, 3),
                round(experiment.result.cuda_2d_gbps, 3),
                round(experiment.result.cuda_3d_gbps, 3),
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()