@@ -63,13 +63,32 @@ def _get_dtype(self):
             return torch.float16

     def get_input_iter(self):
-        def _get_scale_per_tensor(x: torch.Tensor, custom_scale: float = None) -> torch.Tensor:
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
             # For tensor-wise scaling, kernel requires a float32 scale tensor
             if custom_scale:
                 return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
             scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
             return scale.to(torch.float32)

+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = (
@@ -80,26 +99,33 @@ def args(m, n, k):
             )

             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = _get_scale_per_tensor(a, custom_scale=self.extra_args.per_tensor_scale_a)
-                scale_b = _get_scale_per_tensor(b, custom_scale=self.extra_args.per_tensor_scale_b)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
             b = b.to(torch.float8_e4m3fn)

             return (a, b, scale_a, scale_b)

-        if hasattr(self, 'external_shapes') and self.external_shapes:  # Check for external shapes loaded from input-loader
+        if (
+            hasattr(self, "external_shapes") and self.external_shapes
+        ):  # Check for external shapes loaded from input-loader
             for shape in self.external_shapes:
                 if len(shape) == 3:
                     m, n, k = shape
                     yield args(m, n, k)
                 else:
-                    logger.warning(f"Skipping invalid shape: {shape}, expected [M, N, K]")
+                    logger.warning(
+                        f"Skipping invalid shape: {shape}, expected [M, N, K]"
+                    )
         elif self.extra_args.llama:
             for m, n, k, _bias in llama_shapes():
                 yield args(m, n, k)
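
For reference, a minimal sketch of how the `(a, b, scale_a, scale_b)` tuple built by `args` could be consumed by a row-wise scaled FP8 GEMM. The use of `torch._scaled_mm`, the shapes, and the device are illustrative assumptions and not part of this diff; it assumes a CUDA GPU with FP8 support and a PyTorch build whose `torch._scaled_mm` accepts row-wise scale tensors:

```python
import torch

# Assumptions (not from the diff): CUDA GPU with FP8 support and a PyTorch
# build where torch._scaled_mm supports row-wise scale tensors.
m, n, k = 1024, 1024, 1024
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16).T  # [k, n], column-major

# Row-wise scales, mirroring _get_scale_per_row: shape [m, 1] for a, [1, n] for b.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = (fp8_max / a.abs().max(dim=1, keepdim=True).values).to(torch.float32)
scale_b = (fp8_max / b.abs().max(dim=0, keepdim=True).values).to(torch.float32)

# Cast to FP8, as at the end of args(), then run the scaled matmul.
a_fp8 = a.to(torch.float8_e4m3fn)
b_fp8 = b.to(torch.float8_e4m3fn)
out = torch._scaled_mm(
    a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
)
print(out.shape)  # torch.Size([1024, 1024])
```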