@@ -1,5 +1,7 @@
 import argparse
 import logging
+import csv
+import os
 
 from typing import Any, Callable, List, Optional
 
@@ -33,16 +35,29 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
-
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
     parser.add_argument("--scaling_rowwise", action="store_true")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
    parser.add_argument("--n", type=int)
+    parser.add_argument("--filepath", type=str, default=None)
     return parser.parse_args(args)
 
+def read_fp8_shapes(filepath):
+    fp8_shapes = []
+    try:
+        with open(filepath, 'r', newline='') as csvfile:
+            filtered_lines = (line for line in csvfile if line.strip() and not line.lstrip().startswith('#'))
+            reader = csv.reader(filtered_lines)
+            for row in reader:
+                fp8_shapes.append(tuple(map(int, row)))
+    except Exception as e:
+        logger.error(f"Failed to read fp8 shapes from {filepath}: {e}")
+        raise e
+    return fp8_shapes
+
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
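Review note: given the filtering above, the shapes file is plain CSV, one m,n,k triple per line; blank lines and lines starting with '#' are skipped. A minimal sketch of the round trip, with a hypothetical file name:

# Minimal sketch: write a sample shapes file, then parse it with read_fp8_shapes.
# "shapes.csv" is a hypothetical path; any file passed via --filepath works.
sample = "# m,n,k per line\n1024,1024,1024\n4096,4096,4096\n"
with open("shapes.csv", "w") as f:
    f.write(sample)

print(read_fp8_shapes("shapes.csv"))
# -> [(1024, 1024, 1024), (4096, 4096, 4096)]

Returning tuples keeps the unpacking in get_input_iter below (for m, n, k in fp8_shapes) straightforward.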
@@ -70,6 +85,10 @@ def args(m, n, k):
                 yield args(m, n, k)
         elif self.extra_args.m:
             yield args(self.extra_args.m, self.extra_args.n, self.extra_args.k)
+        elif self.extra_args.filepath:
+            fp8_shapes = read_fp8_shapes(self.extra_args.filepath)
+            for m, n, k in fp8_shapes:
+                yield args(m, n, k)
         else:
             for i in range(10, 15):
                 for j in range(0, 4):
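Review note: the new branch is checked after --llama and the explicit --m/--n/--k path, so those flags take precedence over --filepath, and the hard-coded sweep stays the fallback. A hypothetical invocation through the TritonBench runner (entry-point name assumed):

python run.py --op fp8_gemm --filepath shapes.csv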
@@ -114,8 +133,8 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
             scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
             out_dtype = torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
+            scale_a = torch.tensor(1.0, dtype=torch.float32, device=a.device)
+            scale_b = torch.tensor(1.0, dtype=torch.float32, device=a.device)
             out_dtype = torch.float16
         f = lambda a, b: torch._scaled_mm(
             a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
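Review note: this hunk makes the fp32 dtype of the tensor-wise scales explicit; torch._scaled_mm requires float32 scale tensors, so spelling it out guards against a changed global default dtype. A self-contained sketch of the call path exercised above; hedged, since torch._scaled_mm is a private PyTorch API whose signature has shifted across releases, and fp8 matmul needs fp8-capable hardware (e.g. H100):

import torch

M, K, N = 16, 32, 16
# fp8 operands; _scaled_mm wants the second operand column-major, hence the transpose.
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
# Tensor-wise scales must be fp32 scalars on the same device as the operands.
scale_a = torch.tensor(1.0, dtype=torch.float32, device="cuda")
scale_b = torch.tensor(1.0, dtype=torch.float32, device="cuda")
out = torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16)
print(out.shape)  # torch.Size([16, 16])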