Skip to content

Commit b18d2ba

Browse files
authored
Enable testing preconfigured internal shapes
Differential Revision: D85905941 Pull Request resolved: #610
1 parent 0531069 commit b18d2ba

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

tritonbench/operators/addmm/data_io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ def parse_args(args: List[str]) -> argparse.Namespace:
1111
parser.add_argument("--col-major", type=bool, default=False)
1212
parser.add_argument("--large-k-shapes", type=bool, default=False)
1313
parser.add_argument("--bias-1D-y", type=bool, default=False)
14+
parser.add_argument("--config", type=str, default=None)
1415
args = parser.parse_args(args)
1516
return args

tritonbench/operators/addmm/operator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
import torch._inductor.config as inductor_config
77
import triton
8+
from tritonbench.utils.env_utils import is_fbcode
89

910
try:
1011
from hammer.ops.triton.triton_hstu_linear import triton_addmm
@@ -26,6 +27,11 @@
2627

2728
from .data_io import parse_args
2829

30+
if is_fbcode():
31+
from tritonbench.utils.fb.addmm_prod import get_prod_shapes
32+
else:
33+
get_prod_shapes = lambda x: None
34+
2935

3036
# Shape encoding information: (M, K, N, BIAS_1D_Y)
3137
BUILDIN_SHAPES = [
@@ -88,7 +94,10 @@ def __init__(
8894
):
8995
super().__init__(tb_args, extra_args)
9096
addmm_args = parse_args(self.extra_args)
91-
if addmm_args.m and addmm_args.n and addmm_args.k and addmm_args.bias_1D_y:
97+
prod_shapes = get_prod_shapes(addmm_args.config)
98+
if prod_shapes:
99+
self.shapes = prod_shapes
100+
elif addmm_args.m and addmm_args.n and addmm_args.k and addmm_args.bias_1D_y:
92101
self.shapes = [
93102
(addmm_args.m, addmm_args.k, addmm_args.n, addmm_args.bias_1D_y)
94103
]

0 commit comments

Comments
 (0)