Skip to content

Commit dbcadfa

Browse files
committed
Support _native_multi_head_attention.
1 parent 1415926 commit dbcadfa

File tree

1 file changed

+50
-26
lines changed

1 file changed

+50
-26
lines changed

graph_net/torch/collect_stats.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ class OpStat:
5555
count: int = 0
5656

5757

58+
def resolve_native_multi_head_attention(*args, **kwargs):
59+
query, key, value = args[0], args[1], args[2]
60+
seq_len, batch_size, embed_dim = query.shape
61+
attn_output = torch.empty(
62+
(seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
63+
)
64+
65+
# seq_len_k = key.shape[0]
66+
# num_heads = args[4]
67+
# attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
68+
# dtype=query.dtype, device='meta')
69+
return attn_output # , attn_output_weights
70+
71+
5872
def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
5973
attr_itr = node.target.split(".")
6074
val = gm
@@ -65,13 +79,13 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
6579

6680

6781
def collect_op_stats(model, input_dict):
68-
# FX symbolic trace
6982
try:
83+
# FX symbolic trace
7084
traced = torch.fx.symbolic_trace(model)
7185
# print(traced.graph)
7286
except Exception:
7387
print("Failed to FX symbolic trace")
74-
return None
88+
return False, None
7589

7690
# Use meta tensors as input to avoid actually running the model
7791
meta_input_dict = {}
@@ -80,8 +94,9 @@ def collect_op_stats(model, input_dict):
8094
torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
8195
)
8296

83-
node_outputs = {}
97+
is_complete = True
8498
op_stats = {}
99+
node_outputs = {}
85100
for node in traced.graph.nodes:
86101
op_name = None
87102
dtype = None
@@ -99,31 +114,35 @@ def collect_op_stats(model, input_dict):
99114
lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
100115
)
101116

102-
if node.op == "call_module":
103-
# classname of module
104-
submod = traced.get_submodule(node.target)
105-
op_name = submod.__class__.__name__
106-
op_func = submod
107-
elif node.op == "call_function":
108-
op_name = node.target.__name__
109-
op_func = node.target
110-
elif node.op == "call_method":
111-
op_name = node.target
112-
self_obj = (
113-
node_outputs[node.args[0].name]
114-
if isinstance(node.args[0], torch.fx.Node)
115-
else node.args[0]
116-
)
117-
op_func = getattr(self_obj, node.target)
118-
node_args = node_args[1:]
119-
120117
try:
121-
out = op_func(*node_args, **node_kwargs)
118+
if node.op == "call_module":
119+
# classname of module
120+
submod = traced.get_submodule(node.target)
121+
op_name = submod.__class__.__name__
122+
op_func = submod
123+
elif node.op == "call_function":
124+
op_name = node.target.__name__
125+
op_func = node.target
126+
elif node.op == "call_method":
127+
op_name = node.target
128+
self_obj = (
129+
node_outputs[node.args[0].name]
130+
if isinstance(node.args[0], torch.fx.Node)
131+
else node.args[0]
132+
)
133+
op_func = getattr(self_obj, node.target)
134+
node_args = node_args[1:]
135+
136+
if op_name == "_native_multi_head_attention":
137+
out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
138+
else:
139+
out = op_func(*node_args, **node_kwargs)
122140
node_outputs[node.name] = out
123141
dtype = out.dtype if isinstance(out, torch.Tensor) else None
124142
except Exception:
125143
print(f"dtype inference failed: node.op={node.op}, op_name={op_name}")
126144
node_outputs[node.name] = None
145+
is_complete = False
127146
elif node.op == "get_attr":
128147
op_name = node.op
129148
out = resolve_get_attr(traced, node)
@@ -149,11 +168,16 @@ def collect_op_stats(model, input_dict):
149168
else:
150169
op_stats[op_name].dtype.add(dtype_str)
151170
op_stats[op_name].count = op_stats[op_name].count + 1
152-
return op_stats
171+
return is_complete, op_stats
153172

154173

155174
def collect_model_stats(model_path, device, log_prompt):
156-
print(f"Collect information for {model_path}")
175+
if not hasattr(collect_model_stats, "_counter"):
176+
collect_model_stats._counter = 0
177+
else:
178+
collect_model_stats._counter += 1
179+
print(f"[{collect_model_stats._counter}] Collect information for {model_path}")
180+
157181
model_class = load_class_from_file(
158182
os.path.join(model_path, "model.py"), "GraphModule"
159183
)
@@ -164,7 +188,7 @@ def collect_model_stats(model_path, device, log_prompt):
164188
num_inputs = 0
165189
num_outputs = 0
166190
dtypes = set()
167-
op_stats = collect_op_stats(model, input_dict)
191+
is_complete, op_stats = collect_op_stats(model, input_dict)
168192
if op_stats is not None:
169193
for op_name, stat in op_stats.items():
170194
if op_name == "placeholder":
@@ -192,7 +216,7 @@ def collect_model_stats(model_path, device, log_prompt):
192216
dtypes_str = "[" + ",".join(dtypes) + "]"
193217
param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
194218
print(
195-
f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str}",
219+
f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete}",
196220
file=sys.stderr,
197221
flush=True,
198222
)

0 commit comments

Comments
 (0)