Commit b5bd25b

Support _native_multi_head_attention.

1 parent 1415926

File tree: 1 file changed, +19 −2 lines

graph_net/torch/collect_stats.py

Lines changed: 19 additions & 2 deletions
@@ -55,6 +55,20 @@ class OpStat:
     count: int = 0


+def resolve_native_multi_head_attention(*args, **kwargs):
+    query, key, value = args[0], args[1], args[2]
+    seq_len, batch_size, embed_dim = query.shape
+    attn_output = torch.empty(
+        (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
+    )
+
+    # seq_len_k = key.shape[0]
+    # num_heads = args[4]
+    # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
+    #                                   dtype=query.dtype, device='meta')
+    return attn_output  # , attn_output_weights
+
+
 def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
     attr_itr = node.target.split(".")
     val = gm
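
For illustration (not part of the commit): the helper only fabricates an output tensor with the query's (seq_len, batch_size, embed_dim) layout on the meta device, so it can be exercised with shape-only meta tensors and no real attention math. A minimal sketch, assuming the function above is importable from graph_net/torch/collect_stats.py and using made-up shapes:

    import torch
    from graph_net.torch.collect_stats import resolve_native_multi_head_attention

    # Made-up shapes: 16 tokens, batch of 4, embedding width 64.
    query = torch.empty((16, 4, 64), device="meta")
    key = torch.empty((16, 4, 64), device="meta")
    value = torch.empty((16, 4, 64), device="meta")

    # No attention computation runs; only shape/dtype metadata is produced.
    out = resolve_native_multi_head_attention(query, key, value)
    print(out.shape, out.device)  # torch.Size([16, 4, 64]) meta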
@@ -65,8 +79,8 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):


 def collect_op_stats(model, input_dict):
-    # FX symbolic trace
     try:
+        # FX symbolic trace
         traced = torch.fx.symbolic_trace(model)
         # print(traced.graph)
     except Exception:
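
Context for the hunk above: torch.fx.symbolic_trace can raise on models with data-dependent control flow, which is why the call sits inside a try block; this change just moves the comment inside it. A standalone sketch of the tracing step on a toy module (illustrative, not from this repo):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    traced = torch.fx.symbolic_trace(Toy())
    for node in traced.graph.nodes:
        # Each node carries an op kind (placeholder, call_function, output, ...)
        # and a target (the function or attribute it invokes).
        print(node.op, node.target)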
@@ -118,7 +132,10 @@ def collect_op_stats(model, input_dict):
            node_args = node_args[1:]

        try:
-            out = op_func(*node_args, **node_kwargs)
+            if op_name == "_native_multi_head_attention":
+                out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
+            else:
+                out = op_func(*node_args, **node_kwargs)
             node_outputs[node.name] = out
             dtype = out.dtype if isinstance(out, torch.Tensor) else None
         except Exception:
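
The change above special-cases one op name, presumably because _native_multi_head_attention cannot be executed directly during this shape/dtype replay and needs the hand-written resolver instead. If more such ops appear, the same fallback could be phrased as a lookup table; a sketch under that assumption (RESOLVERS and run_op are hypothetical names, not in the file):

    # Hypothetical refactor of the dispatch above: map op names that cannot be
    # executed during replay to hand-written shape resolvers.
    RESOLVERS = {
        "_native_multi_head_attention": resolve_native_multi_head_attention,
    }

    def run_op(op_name, op_func, node_args, node_kwargs):
        resolver = RESOLVERS.get(op_name)
        if resolver is not None:
            return resolver(*node_args, **node_kwargs)
        return op_func(*node_args, **node_kwargs)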
