Commit c937d85

jananisriram authored and facebook-github-bot committed
Add benchmarking on shapes from CSV files to fp8_gemm

Summary: Add the ability to benchmark fp8_gemm kernels on shapes read from a CSV file, and add a CLI argument that takes the path to that file.

Differential Revision: D80381352
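For reference, the new read_fp8_shapes helper (added in the diff below) skips blank lines and lines starting with '#', and parses each remaining row as a comma-separated m, n, k triple. A shapes file passed via --filepath could therefore look like the following (the values are illustrative only, not part of this commit):

# m, n, k
1024, 1024, 1024
2048, 4096, 8192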
1 parent a404ea7


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 22 additions & 3 deletions
@@ -1,5 +1,7 @@
 import argparse
 import logging
+import csv
+import os
 
 from typing import Any, Callable, List, Optional
 
@@ -33,16 +35,29 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
-
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
     parser.add_argument("--scaling_rowwise", action="store_true")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--filepath", type=str, default=None)
     return parser.parse_args(args)
 
+def read_fp8_shapes(filepath):
+    fp8_shapes = []
+    try:
+        with open(filepath, 'r', newline='') as csvfile:
+            filtered_lines = (line for line in csvfile if line.strip() and not line.lstrip().startswith('#'))
+            reader = csv.reader(filtered_lines)
+            for row in reader:
+                fp8_shapes.append(tuple(map(int, row)))
+    except Exception as e:
+        logger.error(f"Failed to read fp8 shapes from {filepath}: {e}")
+        raise e
+    return fp8_shapes
+
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
@@ -70,6 +85,10 @@ def args(m, n, k):
                 yield args(m, n, k)
         elif self.extra_args.m:
             yield args(self.extra_args.m, self.extra_args.n, self.extra_args.k)
+        elif self.extra_args.filepath:
+            fp8_shapes = read_fp8_shapes(self.extra_args.filepath)
+            for m, n, k in fp8_shapes:
+                yield args(m, n, k)
         else:
             for i in range(10, 15):
                 for j in range(0, 4):
@@ -114,8 +133,8 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
             scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
             out_dtype = torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
+            scale_a = torch.tensor(1.0, dtype=torch.float32, device=a.device)
+            scale_b = torch.tensor(1.0, dtype=torch.float32, device=a.device)
             out_dtype = torch.float16
         f = lambda a, b: torch._scaled_mm(
             a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
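As a quick sanity check of the new code path, the sketch below mirrors the parsing logic of read_fp8_shapes on an in-memory sample; the parse_fp8_shape_rows name and the sample rows are illustrative and not part of this commit. With the new flag, the operator would presumably be invoked as something like python run.py --op fp8_gemm --filepath <shapes.csv>, though the exact entry point is not shown in this diff.

import csv
import io

def parse_fp8_shape_rows(csvfile):
    # Mirror of read_fp8_shapes(): drop blank lines and '#' comments,
    # then parse each remaining row into an (m, n, k) tuple of ints.
    filtered_lines = (
        line
        for line in csvfile
        if line.strip() and not line.lstrip().startswith("#")
    )
    return [tuple(map(int, row)) for row in csv.reader(filtered_lines)]

# Illustrative in-memory stand-in for a shapes CSV file.
sample = io.StringIO("# m, n, k\n1024, 1024, 1024\n\n2048, 4096, 8192\n")
print(parse_fp8_shape_rows(sample))  # [(1024, 1024, 1024), (2048, 4096, 8192)]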

0 commit comments