@@ -63,13 +63,32 @@ def _get_dtype(self):
         return torch.float16

     def get_input_iter(self):
-        def _get_scale_per_tensor(x: torch.Tensor, custom_scale: float = None) -> torch.Tensor:
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
             # For tensor-wise scaling, kernel requires a float32 scale tensor
             if custom_scale:
                 return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
             scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
             return scale.to(torch.float32)

+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = (
@@ -80,26 +99,33 @@ def args(m, n, k):
             )

             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = _get_scale_per_tensor(a, custom_scale=self.extra_args.per_tensor_scale_a)
-                scale_b = _get_scale_per_tensor(b, custom_scale=self.extra_args.per_tensor_scale_b)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
             b = b.to(torch.float8_e4m3fn)

             return (a, b, scale_a, scale_b)

-        if hasattr(self, 'external_shapes') and self.external_shapes:  # Check for external shapes loaded from input-loader
+        if (
+            hasattr(self, "external_shapes") and self.external_shapes
+        ):  # Check for external shapes loaded from input-loader
             for shape in self.external_shapes:
                 if len(shape) == 3:
                     m, n, k = shape
                     yield args(m, n, k)
                 else:
-                    logger.warning(f"Skipping invalid shape: {shape}, expected [M, N, K]")
+                    logger.warning(
+                        f"Skipping invalid shape: {shape}, expected [M, N, K]"
+                    )
         elif self.extra_args.llama:
             for m, n, k, _bias in llama_shapes():
                 yield args(m, n, k)
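The new `_get_scale_per_row` helper yields float32 scale tensors of shape `[M, 1]` for the left operand and `[1, N]` (via `transpose=True`) for the right operand, which is the layout row-wise-scaled fp8 GEMM kernels such as `torch._scaled_mm` expect. As a minimal, hypothetical sketch (not part of this PR; assumes a CUDA GPU and a PyTorch build with fp8 row-wise scaling support, and constructs `b` as `[N, K]` purely for illustration), inputs shaped like the ones yielded by `get_input_iter` would be consumed roughly like this:

```python
import torch

m, n, k = 256, 512, 128  # illustrative sizes (multiples of 16, as fp8 GEMM requires)
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = (fp8_max / a.abs().max(dim=1, keepdim=True).values).to(torch.float32)  # [M, 1]
scale_b = (fp8_max / b.abs().max(dim=1, keepdim=True).values).to(torch.float32)  # [N, 1]

a_fp8 = a.to(torch.float8_e4m3fn)
b_fp8 = b.to(torch.float8_e4m3fn)

# _scaled_mm wants the second operand column-major and float32 scales that
# broadcast over the [M, N] output: scale_a as [M, 1], scale_b as [1, N].
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),            # [K, N], column-major view
    scale_a=scale_a,      # [M, 1]
    scale_b=scale_b.t(),  # [1, N]
    out_dtype=torch.bfloat16,
)
assert out.shape == (m, n)
```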