File tree Expand file tree Collapse file tree 2 files changed +11
-1
lines changed
tritonbench/operators/addmm Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -11,5 +11,6 @@ def parse_args(args: List[str]) -> argparse.Namespace:
1111 parser .add_argument ("--col-major" , type = bool , default = False )
1212 parser .add_argument ("--large-k-shapes" , type = bool , default = False )
1313 parser .add_argument ("--bias-1D-y" , type = bool , default = False )
14+ parser .add_argument ("--config" , type = str , default = None )
1415 args = parser .parse_args (args )
1516 return args
Original file line number Diff line number Diff line change 55import torch
66import torch ._inductor .config as inductor_config
77import triton
8+ from tritonbench .utils .env_utils import is_fbcode
89
910try :
1011 from hammer .ops .triton .triton_hstu_linear import triton_addmm
2627
2728from .data_io import parse_args
2829
30+ if is_fbcode ():
31+ from tritonbench .utils .fb .addmm_prod import get_prod_shapes
32+ else :
33+ get_prod_shapes = lambda x : None
34+
2935
3036# Shape encoding information: (M, K, N, BIAS_1D_Y)
3137BUILDIN_SHAPES = [
@@ -88,7 +94,10 @@ def __init__(
8894 ):
8995 super ().__init__ (tb_args , extra_args )
9096 addmm_args = parse_args (self .extra_args )
91- if addmm_args .m and addmm_args .n and addmm_args .k and addmm_args .bias_1D_y :
97+ prod_shapes = get_prod_shapes (addmm_args .config )
98+ if prod_shapes :
99+ self .shapes = prod_shapes
100+ elif addmm_args .m and addmm_args .n and addmm_args .k and addmm_args .bias_1D_y :
92101 self .shapes = [
93102 (addmm_args .m , addmm_args .k , addmm_args .n , addmm_args .bias_1D_y )
94103 ]
You can’t perform that action at this time.
0 commit comments