Skip to content

Commit 56cba4d

Browse files
jananisriram authored and facebook-github-bot committed
Add benchmarking on shapes from CSV files to fp8_gemm (#332)
Summary: Add ability to benchmark fp8_gemm kernels on shapes from CSV files. Add CLI argument to consume file path for CSV file. Differential Revision: D80381352
1 parent a404ea7 commit 56cba4d

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import csv
23
import logging
34

45
from typing import Any, Callable, List, Optional
@@ -41,9 +42,28 @@ def parse_args(args):
4142
parser.add_argument("--m", type=int)
4243
parser.add_argument("--k", type=int)
4344
parser.add_argument("--n", type=int)
45+
parser.add_argument("--filepath", type=str, default=None)
4446
return parser.parse_args(args)
4547

4648

49+
def read_fp8_shapes(filepath):
    """Read (m, n, k) GEMM shapes from a CSV file.

    The file is expected to contain one shape per row as comma-separated
    integers. Blank lines and lines whose first non-whitespace character
    is ``#`` are treated as comments and skipped.

    Args:
        filepath: Path to the CSV file to read.

    Returns:
        A list of tuples of ints, one per data row (row width is not
        validated here; callers unpack as ``m, n, k``).

    Raises:
        Exception: Any error encountered while opening or parsing the
            file is logged and re-raised unchanged.
    """
    fp8_shapes = []
    try:
        with open(filepath, "r", newline="") as csvfile:
            # Drop blank lines and '#'-prefixed comment lines before they
            # reach the CSV parser, so shape files may be annotated.
            filtered_lines = (
                line
                for line in csvfile
                if line.strip() and not line.lstrip().startswith("#")
            )
            reader = csv.reader(filtered_lines)
            for row in reader:
                fp8_shapes.append(tuple(map(int, row)))
    except Exception as e:
        logger.error(f"Failed to read fp8 shapes from {filepath}: {e}")
        # Bare `raise` re-raises the active exception; `raise e` would
        # rebind the same object for no benefit.
        raise
    return fp8_shapes
65+
66+
4767
class Operator(BenchmarkOperator):
4868
DEFAULT_METRICS = ["tflops", "gbps", "latency"]
4969
DEFAULT_PRECISION = "fp8"
@@ -70,6 +90,10 @@ def args(m, n, k):
7090
yield args(m, n, k)
7191
elif self.extra_args.m:
7292
yield args(self.extra_args.m, self.extra_args.n, self.extra_args.k)
93+
elif self.extra_args.filepath:
94+
fp8_shapes = read_fp8_shapes(self.extra_args.filepath)
95+
for m, n, k in fp8_shapes:
96+
yield args(m, n, k)
7397
else:
7498
for i in range(10, 15):
7599
for j in range(0, 4):
@@ -114,8 +138,8 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
114138
scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
115139
out_dtype = torch.bfloat16
116140
else:
117-
scale_a = torch.tensor(1.0, device=a.device)
118-
scale_b = torch.tensor(1.0, device=a.device)
141+
scale_a = torch.tensor(1.0, dtype=torch.float32, device=a.device)
142+
scale_b = torch.tensor(1.0, dtype=torch.float32, device=b.device)
119143
out_dtype = torch.float16
120144
f = lambda a, b: torch._scaled_mm(
121145
a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype

tritonbench/operators/gemm/stream_k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,6 @@ def grid(META):
646646
K, #
647647
FP8_OUTPUT=dtype == torch.float8_e4m3fn, #
648648
ENABLE_BUFFER_OPS_ASSUMES=True, #
649-
NUM_SMS=num_sms #
649+
NUM_SMS=num_sms, #
650650
)
651651
return c

0 commit comments

Comments
 (0)