
Commit 207889e

jananisriram authored and facebook-github-bot committed
Add amax as default per-row scaling factor for fp8_gemm benchmark (#341)
Summary:
Pull Request resolved: #341

Add `amax` (absolute maximum) as the default per-row scaling factor for fp8 GEMMs, matching what is used in practice.

Reviewed By: NikhilAPatel, xuzhao9

Differential Revision: D80590746
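For readers unfamiliar with the recipe: per-row amax scaling uses each row's absolute maximum to map that row onto the representable range of float8_e4m3fn (±448), which is the same computation the new `_get_scale_per_row` helper in the diff below performs. A minimal, self-contained sketch (tensor names are illustrative, not taken from the benchmark):

```python
import torch

x = torch.randn(4, 8, dtype=torch.float32)           # toy activation matrix [M, K]
fp8_max = torch.finfo(torch.float8_e4m3fn).max       # 448.0 for e4m3fn

# Per-row quantization scale: each row's amax determines how far the row
# must be stretched (or shrunk) to span the fp8 range.
scale = fp8_max / x.abs().amax(dim=1, keepdim=True)  # shape [M, 1]

x_fp8 = (x * scale).to(torch.float8_e4m3fn)          # quantize
x_rec = x_fp8.to(torch.float32) / scale              # dequantize for comparison
print((x - x_rec).abs().max())                       # small per-row quantization error
```

Per-tensor scaling uses a single global amax instead; that path is unchanged in this commit apart from reformatting.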
1 parent 0e0d6f6 commit 207889e

File tree: 1 file changed, +28 -6 lines


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 28 additions & 6 deletions
@@ -63,13 +63,32 @@ def _get_dtype(self):
         return torch.float16
 
     def get_input_iter(self):
-        def _get_scale_per_tensor(x: torch.Tensor, custom_scale: float = None) -> torch.Tensor:
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
             # For tensor-wise scaling, kernel requires a float32 scale tensor
             if custom_scale:
                 return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
             scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
             return scale.to(torch.float32)
 
+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = (
@@ -80,12 +99,15 @@ def args(m, n, k):
             )
 
             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = _get_scale_per_tensor(a, custom_scale=self.extra_args.per_tensor_scale_a)
-                scale_b = _get_scale_per_tensor(b, custom_scale=self.extra_args.per_tensor_scale_b)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
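The [M, 1] / [1, N] scale shapes produced above are what row-wise scaled matmul kernels such as `torch._scaled_mm` expect; the call site itself is not part of this diff. Below is a rough sketch of how such scales could be consumed, assuming PyTorch 2.5+ on an fp8-capable GPU (e.g. H100). Note that `_scaled_mm` applies its scale arguments as dequantization multipliers, so the sketch passes the reciprocals of the quantization scales:

```python
import torch

M, N, K = 64, 32, 128
device = "cuda"
fp8_max = torch.finfo(torch.float8_e4m3fn).max

a = torch.randn(M, K, device=device)
b = torch.randn(N, K, device=device)  # passed as b.t(); _scaled_mm wants a column-major second operand

# Per-row amax quantization scales, same recipe as _get_scale_per_row.
scale_a = fp8_max / a.abs().amax(dim=1, keepdim=True)  # [M, 1]
scale_b = fp8_max / b.abs().amax(dim=1, keepdim=True)  # [N, 1]; becomes [1, N] after .t()

a_fp8 = (a * scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b * scale_b).to(torch.float8_e4m3fn)

# Scale shapes [M, 1] and [1, N] select the row-wise scaling mode; the
# reciprocal scales undo the quantization in the GEMM epilogue.
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    scale_a=scale_a.reciprocal(),
    scale_b=scale_b.reciprocal().t(),
    out_dtype=torch.bfloat16,
)
print(out.shape)  # torch.Size([64, 32])
```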
