Skip to content

Commit 272ab55

Browse files
committed
Implement a function to collect the model's execution stats.
1 parent c1bc381 commit 272ab55

File tree

2 files changed

+205
-1
lines changed

2 files changed

+205
-1
lines changed

graph_net/torch/collect_stats.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import argparse
2+
import os
3+
import importlib
4+
from typing import Type
5+
from dataclasses import dataclass, field
6+
from collections import defaultdict
7+
8+
import torch
9+
from torch.fx.passes.shape_prop import ShapeProp
10+
from graph_net.torch import utils
11+
12+
13+
def is_single_model_dir(model_dir):
14+
return os.path.isfile(f"{model_dir}/graph_net.json")
15+
16+
17+
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
18+
spec = importlib.util.spec_from_file_location("unnamed", file_path)
19+
unnamed = importlib.util.module_from_spec(spec)
20+
spec.loader.exec_module(unnamed)
21+
model_class = getattr(unnamed, class_name, None)
22+
return model_class
23+
24+
25+
def get_input_dict(model_path, device):
26+
inputs_params = utils.load_converted_from_text(f"{model_path}")
27+
params = inputs_params["weight_info"]
28+
for tensor_meta in params.values():
29+
if hasattr(tensor_meta, "device"):
30+
tensor_meta.device = device
31+
return {
32+
k: utils.replay_tensor(v).to(torch.device(device)) for k, v in params.items()
33+
}
34+
35+
36+
@dataclass
37+
class OpStat:
38+
op_name: str
39+
dtype: set[str] = field(default_factory=set)
40+
count: int = 0
41+
42+
43+
def collect_op_stats(model, input_dict):
44+
# Use meta tensors as input to avoid actually running the model
45+
meta_input_dict = {}
46+
for name, x in input_dict.items():
47+
meta_input_dict[name] = (
48+
torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
49+
)
50+
51+
# FX symbolic trace
52+
traced = torch.fx.symbolic_trace(model)
53+
# print(traced.graph)
54+
55+
node_outputs = {}
56+
op_stats = {}
57+
for node in traced.graph.nodes:
58+
op_name = None
59+
dtype = None
60+
if node.op == "placeholder":
61+
node_outputs[node.name] = meta_input_dict[node.target]
62+
op_name = node.op
63+
dtype = node_outputs[node.name].dtype
64+
elif node.op in ["call_function", "call_method", "call_module"]:
65+
node_args = []
66+
for arg in node.args:
67+
node_args.append(
68+
node_outputs[arg.name] if hasattr(arg, "name") else arg
69+
)
70+
node_kwargs = {}
71+
for k, v in node.kwargs.items():
72+
node_kwargs[k] = node_outputs[v.name] if hasattr(v, "name") else v
73+
74+
if node.op == "call_module":
75+
# classname of module
76+
submod = dict(traced.named_modules())[node.target]
77+
op_name = submod.__class__.__name__
78+
try:
79+
out = submod(*node_args, **node_kwargs)
80+
node_outputs[node.name] = out
81+
dtype = out.dtype if isinstance(out, torch.Tensor) else None
82+
except Exception:
83+
node_outputs[node.name] = None
84+
elif node.op in ["call_function", "call_method"]:
85+
op_name = (
86+
node.target.__name__ if node.op == "call_function" else node.target
87+
)
88+
try:
89+
out = node.target(*node_args, **node_kwargs)
90+
node_outputs[node.name] = out
91+
dtype = out.dtype if isinstance(out, torch.Tensor) else None
92+
except Exception:
93+
print(f"Dtype inference failed: op_name={op_name}")
94+
node_outputs[node.name] = None
95+
elif node.op == "output":
96+
op_name = node.op
97+
node_args = []
98+
for arg in node.args:
99+
node_args.append(
100+
node_outputs[arg.name] if hasattr(arg, "name") else arg
101+
)
102+
node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args
103+
dtype = (
104+
node_args[0].dtype if isinstance(node_args[0], torch.Tensor) else None
105+
)
106+
else:
107+
assert False
108+
109+
if op_name is not None:
110+
dtype_str = str(dtype).replace("torch.", "") if dtype is not None else None
111+
if op_stats.get(op_name, None) is None:
112+
op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
113+
else:
114+
op_stats[op_name].dtype.add(dtype_str)
115+
op_stats[op_name].count = op_stats[op_name].count + 1
116+
return op_stats
117+
118+
119+
def collect_model_stats(model_path, device):
120+
print(f"Collect information for {model_path}")
121+
model_class = load_class_from_file(
122+
os.path.join(model_path, "model.py"), "GraphModule"
123+
)
124+
model = model_class()
125+
input_dict = get_input_dict(model_path, device)
126+
127+
num_ops = 0
128+
num_inputs = 0
129+
num_outputs = 0
130+
dtype = set()
131+
op_stats = collect_op_stats(model, input_dict)
132+
for op_name, stat in op_stats.items():
133+
if op_name == "placeholder":
134+
num_inputs += stat.count
135+
elif op_name == "output":
136+
num_outputs += stat.count
137+
else:
138+
num_ops += stat.count
139+
for v in stat.dtype:
140+
if v is not None:
141+
dtype.add(v)
142+
143+
# num_params_in_gb = 0
144+
# for name, tensor in input_dict.items():
145+
# if isinstance
146+
147+
print(f"Information of {model_path}:")
148+
print(f"- num_inputs : {num_inputs}")
149+
print(f"- num_outputs : {num_outputs}")
150+
print(f"- num_ops : {num_ops}")
151+
print(f"- dtype : {dtype}")
152+
153+
154+
def main(args):
155+
if args.model_path is not None:
156+
assert os.path.isdir(args.model_path)
157+
assert is_single_model_dir(args.model_path)
158+
collect_model_stats(args.model_path, args.device)
159+
else:
160+
graph_net_samples_path = (
161+
(graph_net.torch.samples_util.get_default_samples_directory())
162+
if args.graph_net_samples_path is None
163+
else args.graph_net_samples_path
164+
)
165+
for root, dirs, files in os.walk(graph_net_samples_path):
166+
if is_single_model_dir(root):
167+
collect_model_stats(root, args.device)
168+
169+
170+
if __name__ == "__main__":
171+
parser = argparse.ArgumentParser(
172+
description="Validate a computation graph sample. return 0 if success"
173+
)
174+
parser.add_argument(
175+
"--device",
176+
type=str,
177+
required=False,
178+
default="cuda",
179+
help="Device for testing the compiler (e.g., 'cpu' or 'cuda')",
180+
)
181+
parser.add_argument(
182+
"--model-path",
183+
type=str,
184+
required=False,
185+
default=None,
186+
help="Computation graph sample directory. e.g '../../samples/torch/resnet18'",
187+
)
188+
parser.add_argument(
189+
"--graph-net-samples-path",
190+
type=str,
191+
required=False,
192+
default=None,
193+
help="GraphNet samples directory. e.g '../../samples'",
194+
)
195+
parser.add_argument(
196+
"--workspace",
197+
default=os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "./workspace"),
198+
help="temporary directory for validation (default: env var GRAPH_NET_EXTRACT_WORKSPACE). ",
199+
)
200+
args = parser.parse_args()
201+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace
202+
203+
main(args=args)

graph_net/torch/test_compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from . import utils
21
import argparse
32
import importlib.util
43
import inspect
@@ -14,6 +13,8 @@
1413
import json
1514
import numpy as np
1615
import platform
16+
17+
from graph_net.torch import utils
1718
from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend
1819
from graph_net.torch.backend.tvm_backend import TvmBackend
1920
from graph_net.torch.backend.xla_backend import XlaBackend

0 commit comments

Comments
 (0)