@@ -1,4 +1,5 @@
 import argparse
+import csv
 import logging
 
 from typing import Any, Callable, List, Optional
@@ -41,9 +42,28 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--filepath", type=str, default=None)
     return parser.parse_args(args)
 
 
+def read_fp8_shapes(filepath):
+    fp8_shapes = []
+    try:
+        with open(filepath, "r", newline="") as csvfile:
+            filtered_lines = (
+                line
+                for line in csvfile
+                if line.strip() and not line.lstrip().startswith("#")
+            )
+            reader = csv.reader(filtered_lines)
+            for row in reader:
+                fp8_shapes.append(tuple(map(int, row)))
+    except Exception as e:
+        logger.error(f"Failed to read fp8 shapes from {filepath}: {e}")
+        raise e
+    return fp8_shapes
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
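
For context, the new read_fp8_shapes helper expects a plain CSV of comma-separated m,n,k triples, skipping blank lines and lines whose first non-whitespace character is "#". A minimal sketch of a compatible shapes file and the resulting parse (the filename shapes.csv is just an illustration):

    # fp8 gemm shapes: m,n,k
    1024,1024,1024
    8192,4096,4096

    >>> read_fp8_shapes("shapes.csv")
    [(1024, 1024, 1024), (8192, 4096, 4096)]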
@@ -70,6 +90,10 @@ def args(m, n, k):
                 yield args(m, n, k)
         elif self.extra_args.m:
             yield args(self.extra_args.m, self.extra_args.n, self.extra_args.k)
+        elif self.extra_args.filepath:
+            fp8_shapes = read_fp8_shapes(self.extra_args.filepath)
+            for m, n, k in fp8_shapes:
+                yield args(m, n, k)
         else:
             for i in range(10, 15):
                 for j in range(0, 4):
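
With this change, shape selection falls through three cases: explicit --m/--n/--k, then --filepath, then the default power-of-two sweep. A hedged sketch of driving the benchmark with the new flag (the runner script and operator name below are assumptions based on the usual TritonBench layout, not taken from this diff):

    python run.py --op fp8_gemm --filepath shapes.csv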
@@ -114,8 +138,8 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
             scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
             out_dtype = torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
+            scale_a = torch.tensor(1.0, dtype=torch.float32, device=a.device)
+            scale_b = torch.tensor(1.0, dtype=torch.float32, device=b.device)
             out_dtype = torch.float16
         f = lambda a, b: torch._scaled_mm(
             a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
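
The fix here pins both scale tensors to float32, which torch._scaled_mm requires for its scale arguments, and also places scale_b on b.device instead of a.device. A standalone sketch mirroring the call pattern in this hunk (torch._scaled_mm is a private API whose signature has shifted across PyTorch releases and needs fp8-capable hardware; the shapes and e4m3 dtype below are illustrative assumptions):

    import torch

    m, k, n = 64, 128, 32
    # _scaled_mm expects a row-major and b column-major, both fp8
    a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()
    scale_a = torch.tensor(1.0, dtype=torch.float32, device=a.device)
    scale_b = torch.tensor(1.0, dtype=torch.float32, device=b.device)
    out = torch._scaled_mm(
        a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
    )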