Commit 1e5d7e3

Add fp8_gemm benchmark for deepseek-style scaling
Differential Revision: D83689980
Pull Request resolved: #504

1 parent: 474a303

File tree

1 file changed (+54, −5)
tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 54 additions & 5 deletions
@@ -45,6 +45,10 @@
    HAS_TMA = False
    logger.warning(f"Failed to import TMA: {e}")

+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+

def parse_args(args):
    parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
@@ -73,6 +77,10 @@ def get_scaling_recipe(scaling_recipe: str) -> int:
        return ScalingType.TensorWise
    elif scaling_recipe == "RowWise":
        return ScalingType.RowWise
+    elif scaling_recipe == "BlockWise1x128":
+        return ScalingType.BlockWise1x128
+    elif scaling_recipe == "BlockWise128x128":
+        return ScalingType.BlockWise128x128
    else:
        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")

@@ -89,7 +97,7 @@ def _get_scale_per_tensor(
        # For tensor-wise scaling, kernel requires a float32 scale tensor
        if custom_scale:
            return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        scale = (torch.finfo(torch.float8_e4m3fn).max / x.abs().max()).reciprocal()
        x *= scale
        return x, scale.to(torch.float32)

@@ -100,22 +108,46 @@ def _get_scale_per_row(
            scale = (
                torch.finfo(torch.float8_e4m3fn).max
                / x.abs().max(dim=0, keepdim=True).values
-            )
+            ).reciprocal()
        else:  # scale_a.shape should be [M, 1]
            scale = (
                torch.finfo(torch.float8_e4m3fn).max
                / x.abs().max(dim=1, keepdim=True).values
-            )
+            ).reciprocal()
        x = x.mul(scale)
        return x, scale.to(
            torch.float32
        )  # For row-wise scaling, kernel requires a float32 scale tensor

+    def _get_scale_per_block(
+        x: torch.Tensor, block_outer: int, block_inner: int
+    ) -> (torch.Tensor, torch.Tensor):
+        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+        scale = (
+            torch.finfo(torch.float8_e4m3fn).max / amax
+        ).reciprocal()  # keeps scale small enough such that scaling doesn't cause inf values
+        x = (
+            x.mul(scale).flatten(2, 3).flatten(0, 1)
+        )  # scale input up to dynamic range of float8_e4m3fn
+        scale = scale.flatten(2, 3).flatten(0, 1)
+
+        if block_outer == 1 and block_inner == 128:
+            scale = (
+                scale.t().contiguous().t()
+            )  # 1x128 blocks need scales to be outer-dim-major
+
+        return x, scale.to(torch.float32)
+
    match scaling_recipe:
        case ScalingType.TensorWise:
            return _get_scale_per_tensor(x, custom_scale=custom_scale)
        case ScalingType.RowWise:
            return _get_scale_per_row(x, transpose=transpose)
+        case ScalingType.BlockWise1x128:
+            return _get_scale_per_block(x, 1, 128)
+        case ScalingType.BlockWise128x128:
+            return _get_scale_per_block(x, 128, 128)
        case _:
            raise AssertionError(f"Unsupported scaling type {scaling_recipe}")

@@ -143,6 +175,19 @@ def __init__(
        self.scaling_recipe_a = get_scaling_recipe(scaling_recipe_a)
        self.scaling_recipe_b = get_scaling_recipe(scaling_recipe_b)

+        blockwise_scaling_types = [
+            ScalingType.BlockWise1x128,
+            ScalingType.BlockWise128x128,
+        ]
+        self.contains_blockwise_scaling = (
+            self.scaling_recipe_a in blockwise_scaling_types
+            or self.scaling_recipe_b in blockwise_scaling_types
+        )
+
+        self.use_fast_accum = (
+            False if self.contains_blockwise_scaling else True
+        )  # BlockWise scaled_gemm does not support use_fast_accum=True
+
    def _get_dtype(self):
        if (
            self.scaling_recipe_a == ScalingType.TensorWise
@@ -205,12 +250,16 @@ def get_x_val(self, example_inputs) -> float:

    @register_benchmark(baseline=True)
    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        assert (
+            not self.contains_blockwise_scaling or HAS_CUDA_129
+        ), "BlockWise scaling variants for scaled_gemm require CUDA 12.9+"
+
        return lambda: torch._scaled_mm(
            a,
            b.t(),
            scale_a,
            scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
            out_dtype=self._get_dtype(),
        )

@@ -227,7 +276,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
            b.t(),
            scale_a,
            scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
            out_dtype=self._get_dtype(),
        )
        compiled = torch.compile(f, dynamic=False)

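As a reading aid (not part of the commit): the sketch below mirrors the tiling and reduction that the new _get_scale_per_block helper performs with unflatten/amax, so the scale shapes produced by the BlockWise1x128 and BlockWise128x128 recipes are easy to see. The tensor sizes and the helper name block_scale_shape are illustrative assumptions, not code from this PR.

import torch

def block_scale_shape(x: torch.Tensor, block_outer: int, block_inner: int):
    # [M, K] -> [M // block_outer, block_outer, K // block_inner, block_inner]
    tiled = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    # One absolute max per (row-block, column-block) tile, as in the commit.
    amax = tiled.abs().amax(dim=[1, 3], keepdim=True).float()
    # Collapse the singleton block dims back out to a 2-D scale tensor.
    return amax.flatten(2, 3).flatten(0, 1).shape

x = torch.randn(256, 512)
print(block_scale_shape(x, 1, 128))    # torch.Size([256, 4]): one scale per row per 128-column block
print(block_scale_shape(x, 128, 128))  # torch.Size([2, 4]): one scale per 128x128 tile

For the 1x128 recipe, the commit additionally lays the scales out outer-dim-major via scale.t().contiguous().t(), per the inline comment in the diff.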
0 commit comments
