@@ -55,6 +55,23 @@ class OpStat:
     count: int = 0


+def resolve_native_multi_head_attention(*args, **kwargs):
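+    # Shape-only stand-in: build the op's output as an empty tensor on the
+    # meta device instead of running the real attention kernel.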
+    query, key, value = args[0], args[1], args[2]
+    seq_len, batch_size, embed_dim = query.shape
+    attn_output = torch.empty(
+        (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
+    )
+
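+    # The native op can also return attention weights; a sketch of their shape is kept below.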
+    # seq_len_k = key.shape[0]
+    # num_heads = args[4]
+    # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
+    #                                   dtype=query.dtype, device='meta')
+    return attn_output  # , attn_output_weights
+
+
 def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
     attr_itr = node.target.split(".")
     val = gm
@@ -65,8 +82,8 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):


 def collect_op_stats(model, input_dict):
-    # FX symbolic trace
     try:
+        # FX symbolic trace
         traced = torch.fx.symbolic_trace(model)
         # print(traced.graph)
     except Exception:
@@ -118,7 +135,11 @@ def collect_op_stats(model, input_dict):
             node_args = node_args[1:]

         try:
-            out = op_func(*node_args, **node_kwargs)
+            if op_name == "_native_multi_head_attention":
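+                # Special case: resolve this op's output shape by hand rather than calling it on meta tensors.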
+                out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
+            else:
+                out = op_func(*node_args, **node_kwargs)
             node_outputs[node.name] = out
             dtype = out.dtype if isinstance(out, torch.Tensor) else None
         except Exception: