     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
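
For context on the new guard: `torch.version.cuda` is `None` on CPU-only builds, so the `and` chain in `HAS_CUDA_129` short-circuits before the string comparison runs. A minimal sketch of what the flag evaluates (illustrative, not part of the PR):

```python
import torch

# None on CPU-only builds, a version string like "12.9" on CUDA builds.
print(torch.version.cuda)

# Mirrors the HAS_CUDA_129 expression added above.
has_cuda_129 = bool(
    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
)
print(has_cuda_129)
```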
@@ -73,6 +77,10 @@ def get_scaling_recipe(scaling_recipe: str) -> int:
         return ScalingType.TensorWise
     elif scaling_recipe == "RowWise":
         return ScalingType.RowWise
+    elif scaling_recipe == "BlockWise1x128":
+        return ScalingType.BlockWise1x128
+    elif scaling_recipe == "BlockWise128x128":
+        return ScalingType.BlockWise128x128
     else:
         raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
 
@@ -89,7 +97,7 @@ def _get_scale_per_tensor(
         # For tensor-wise scaling, kernel requires a float32 scale tensor
         if custom_scale:
             return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        scale = (torch.finfo(torch.float8_e4m3fn).max / x.abs().max()).reciprocal()
         x *= scale
         return x, scale.to(torch.float32)
 
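
To make the `.reciprocal()` change concrete: `torch.finfo(torch.float8_e4m3fn).max` is 448, so the new per-tensor scale is `amax / 448` rather than `448 / amax`. A small worked example (illustrative values, not part of the PR):

```python
import torch

x = torch.tensor([[2.0, -3.5], [1.0, 0.5]])
fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0 for float8_e4m3fn
scale = (fp8_max / x.abs().max()).reciprocal()   # amax / fp8_max = 3.5 / 448
print(scale)                                     # 0.0078125
```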
@@ -100,22 +108,46 @@ def _get_scale_per_row(
             scale = (
                 torch.finfo(torch.float8_e4m3fn).max
                 / x.abs().max(dim=0, keepdim=True).values
-            )
+            ).reciprocal()
         else:  # scale_a.shape should be [M, 1]
             scale = (
                 torch.finfo(torch.float8_e4m3fn).max
                 / x.abs().max(dim=1, keepdim=True).values
-            )
+            ).reciprocal()
         x = x.mul(scale)
         return x, scale.to(
             torch.float32
         )  # For row-wise scaling, kernel requires a float32 scale tensor
 
+    def _get_scale_per_block(
+        x: torch.Tensor, block_outer: int, block_inner: int
+    ) -> (torch.Tensor, torch.Tensor):
+        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+        scale = (
+            torch.finfo(torch.float8_e4m3fn).max / amax
+        ).reciprocal()  # keeps scale small enough such that scaling doesn't cause inf values
+        x = (
+            x.mul(scale).flatten(2, 3).flatten(0, 1)
+        )  # scale input up to dynamic range of float8_e4m3fn
+        scale = scale.flatten(2, 3).flatten(0, 1)
+
+        if block_outer == 1 and block_inner == 128:
+            scale = (
+                scale.t().contiguous().t()
+            )  # 1x128 blocks need scales to be outer-dim-major
+
+        return x, scale.to(torch.float32)
+
     match scaling_recipe:
         case ScalingType.TensorWise:
             return _get_scale_per_tensor(x, custom_scale=custom_scale)
         case ScalingType.RowWise:
             return _get_scale_per_row(x, transpose=transpose)
+        case ScalingType.BlockWise1x128:
+            return _get_scale_per_block(x, 1, 128)
+        case ScalingType.BlockWise128x128:
+            return _get_scale_per_block(x, 128, 128)
         case _:
             raise AssertionError(f"Unsupported scaling type {scaling_recipe}")
 
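
For a sense of the scale shapes `_get_scale_per_block` produces, the sketch below (hypothetical sizes, not part of the PR) reuses its reshape-and-amax step: the 1x128 recipe yields one scale per row per 128-column block, and the 128x128 recipe yields one scale per 128x128 tile.

```python
import torch

M, K = 256, 512                  # hypothetical sizes, divisible by the block dims
x = torch.randn(M, K)

def block_scale_shape(x, block_outer, block_inner):
    # Same reshape as _get_scale_per_block: (M//bo, bo, K//bi, bi), amax over block dims.
    blocks = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = blocks.abs().amax(dim=[1, 3], keepdim=True).float()
    return amax.flatten(2, 3).flatten(0, 1).shape

print(block_scale_shape(x, 1, 128))    # torch.Size([256, 4]) -> (M, K // 128)
print(block_scale_shape(x, 128, 128))  # torch.Size([2, 4])   -> (M // 128, K // 128)
```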
@@ -143,6 +175,19 @@ def __init__(
         self.scaling_recipe_a = get_scaling_recipe(scaling_recipe_a)
         self.scaling_recipe_b = get_scaling_recipe(scaling_recipe_b)
 
+        blockwise_scaling_types = [
+            ScalingType.BlockWise1x128,
+            ScalingType.BlockWise128x128,
+        ]
+        self.contains_blockwise_scaling = (
+            self.scaling_recipe_a in blockwise_scaling_types
+            or self.scaling_recipe_b in blockwise_scaling_types
+        )
+
+        self.use_fast_accum = (
+            False if self.contains_blockwise_scaling else True
+        )  # BlockWise scaled_gemm does not support use_fast_accum=True
+
     def _get_dtype(self):
         if (
             self.scaling_recipe_a == ScalingType.TensorWise
@@ -205,12 +250,16 @@ def get_x_val(self, example_inputs) -> float:
 
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        assert (
+            not self.contains_blockwise_scaling or HAS_CUDA_129
+        ), "BlockWise scaling variants for scaled_gemm require CUDA 12.9+"
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
             out_dtype=self._get_dtype(),
         )
 
@@ -227,7 +276,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
             out_dtype=self._get_dtype(),
         )
         compiled = torch.compile(f, dynamic=False)