Skip to content

Commit 6d9b156

Browse files
committed
Implement collecting stats for paddle.
1 parent e7fd651 commit 6d9b156

File tree

1 file changed

+342
-0
lines changed

1 file changed

+342
-0
lines changed

graph_net/paddle/collect_stats.py

Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
import argparse
import ast
import importlib
import inspect
import json
import math
import os
import subprocess
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Type

import paddle

from graph_net.paddle import utils
def is_single_model_dir(model_dir):
    """Return True when *model_dir* is a single model sample directory.

    A directory qualifies when it contains a ``graph_net.json`` metadata file.
    """
    marker_path = f"{model_dir}/graph_net.json"
    return os.path.isfile(marker_path)
def load_class_from_file(file_path: str, class_name: str) -> Type[paddle.nn.Layer]:
    """Dynamically load *class_name* from the Python source at *file_path*.

    The file is imported under the throwaway module name ``"unnamed"``.
    Returns the class object, or None when the file defines no such name.
    """
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name, None)
def get_argument_name_and_types(model_class, func_name):
    """Map each parameter of ``model_class.<func_name>`` to its annotation.

    The ``self`` parameter is skipped; parameters without an annotation map
    to None.  Returns an empty dict when no such function exists.
    """
    arg_types = {}
    members = inspect.getmembers(model_class, predicate=inspect.isfunction)
    for member_name, func in members:
        if member_name != func_name:
            continue
        signature = inspect.signature(func)
        for param_name, param in signature.parameters.items():
            if param_name == "self":
                continue
            annotation = param.annotation
            # inspect.Parameter.empty is the public alias of inspect._empty.
            arg_types[param_name] = (
                None if annotation is inspect.Parameter.empty else annotation
            )
    return arg_types
def get_number_of_returns(file_path, class_name, func_name):
    """Count how many values ``class_name.func_name`` in *file_path* returns.

    The file is parsed with ``ast`` and the first ``return`` statement of the
    method itself is inspected.  Returns inside nested functions or lambdas
    are ignored, since they are not the method's own returns (the original
    ``ast.walk`` descended into them and could miscount).

    * ``return`` with no value      -> 0
    * ``return a, b, ...``          -> number of tuple elements
    * any other ``return expr``     -> 1
    * class/method/return not found -> 0
    """
    with open(file_path, "r") as source_file:
        source = source_file.read()

    def _own_nodes(func_node):
        # Breadth-first walk of the method body that does NOT descend into
        # nested function definitions or lambdas.
        pending = list(func_node.body)
        while pending:
            node = pending.pop(0)
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
                continue
            yield node
            pending.extend(ast.iter_child_nodes(node))

    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            for method in node.body:
                if isinstance(method, ast.FunctionDef) and method.name == func_name:
                    for stmt in _own_nodes(method):
                        if isinstance(stmt, ast.Return):
                            if stmt.value is None:
                                return 0
                            if isinstance(stmt.value, ast.Tuple):
                                return len(stmt.value.elts)
                            return 1
    return 0
def read_graph_source_and_tag(model_path):
    """Return ``(source, heuristic_tag)`` for the sample at *model_path*.

    Prefers the values recorded in ``graph_net.json``.  When that file is
    missing or unreadable, falls back to guessing from well-known suite names
    appearing in the path.

    Bug fixed: ``json`` was never imported at module level, so ``json.load``
    raised a ``NameError`` that the broad ``except`` silently swallowed and
    the JSON path could never succeed; the fallback always ran.
    """
    try:
        with open(os.path.join(model_path, "graph_net.json"), "r") as f:
            data = json.load(f)
        return data["source"], data["heuristic_tag"]
    except Exception:
        # Heuristic fallback keyed on the sample suite name in the path.
        if "PaddleX" in model_path:
            return "PaddleX", "computer_vision"
        elif "PaddleNLP" in model_path:
            return "PaddleNLP", "nlp"
        elif "PaddleScience" in model_path:
            return "PaddleScience", "scientific_computing"
        else:
            return "unknown", "unknown"
def get_input_spec(model_path):
    """Build the list of ``paddle.static.InputSpec`` for the sample's inputs.

    Shape and dtype for every input are read from the converted input
    metadata stored alongside the model.
    """
    inputs_params_list = utils.load_converted_list_from_text(f"{model_path}")
    return [
        paddle.static.InputSpec(entry["info"]["shape"], entry["info"]["dtype"])
        for entry in inputs_params_list
    ]
@dataclass
class OpStat:
    """Aggregated statistics for one operator kind.

    Attributes:
        op_name: canonical operator name (e.g. ``"conv2d"``).
        op_dtypes: per-dtype occurrence counts, keyed by dtype string.
        count: total number of occurrences of this operator.
    """

    op_name: str
    op_dtypes: dict[str, int] = field(default_factory=dict)
    count: int = 0

    def update(self, other):
        """Merge *other* into this stat; ignored unless the op names match."""
        if not (isinstance(other, OpStat) and self.op_name == other.op_name):
            return
        self.count += other.count
        for dtype_name, dtype_count in other.op_dtypes.items():
            current = self.op_dtypes.get(dtype_name, 0)
            self.op_dtypes[dtype_name] = current + dtype_count
class ProgramAnalyzer:
    """Walks a Paddle PIR program and aggregates per-operator statistics."""

    def __init__(self):
        # Operator name -> OpStat accumulated over the analyzed program.
        self.op_stats = {}
        # Graph input name -> {"dtype": ..., "shape": ...}, from pd_op.data ops.
        self.input_dict = {}
        # Number of non-builtin compute operators seen.
        self.num_ops = 0
        # Number of operators whose result dtype could not be determined.
        self.num_ops_misses_dtypes = 0
        # Becomes False once any operator misses dtype information.
        self.is_complete = True

    def update_op_stats(self, op_name, op_dtype):
        # Record one occurrence of op_name with result dtype op_dtype.
        # The "paddle." prefix is stripped so dtypes log as e.g. "float32".
        if op_name is not None:
            dtype_str = str(op_dtype).replace("paddle.", "")
            if self.op_stats.get(op_name, None) is None:
                # First sighting: the OpStat starts with count 1 for this dtype.
                self.op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
            else:
                self.op_stats[op_name].op_dtypes[dtype_str] = (
                    self.op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
                )
                self.op_stats[op_name].count += 1

    def __call__(self, program):
        # Analyze *program*, populating op_stats / input_dict / counters.
        assert isinstance(program, paddle.base.libpaddle.pir.Program)

        # Reset per-call state so the analyzer instance can be reused
        # (input_dict is intentionally kept across calls).
        self.op_stats = {}
        self.num_ops_misses_dtypes = 0
        self.num_ops = 0
        for block in program.blocks:
            for op in block.ops:
                op_name = None
                op_dtype = None
                if op.name() == "pd_op.data":
                    # Feed op: record the graph input's dtype/shape by name.
                    op_attrs = op.attrs()
                    op_dtype = op_attrs["dtype"]
                    self.input_dict[op_attrs["name"]] = {
                        "dtype": str(op_dtype).replace("paddle.", ""),
                        "shape": op_attrs["shape"],
                    }
                elif not op.name().startswith("builtin."):
                    # Regular compute op: strip the dialect prefix and take
                    # the dtype of the first result when the op has results.
                    self.num_ops += 1
                    op_name = op.name().replace("pd_op.", "")
                    if len(op.results()) > 0:
                        op_dtype = op.results()[0].dtype

                if op_name is not None:
                    self.update_op_stats(op_name, op_dtype)
                elif op_dtype is None:
                    # NOTE(review): this branch only fires for ops with neither
                    # a name nor a dtype (i.e. builtin.* ops) -- confirm whether
                    # compute ops with missing dtypes were meant instead.
                    self.num_ops_misses_dtypes += 1

        if self.num_ops_misses_dtypes > 0:
            self.is_complete = False

    def summary(self):
        # One-line human-readable overview of the analysis result.
        print(
            f"Totally {self.num_ops} operators, and {self.num_ops_misses_dtypes} operators failed to inference dtypes."
        )
def collect_op_stats(model, model_path):
    """Convert *model* to a static graph and analyze its main program.

    Args:
        model: the instantiated ``paddle.nn.Layer`` to analyze.
        model_path: sample directory holding the converted input metadata.

    Returns:
        A populated :class:`ProgramAnalyzer`, or None when the ``to_static``
        conversion fails.
    """
    assert isinstance(model, paddle.nn.Layer), f"{type(model)=}"
    try:
        static_model = paddle.jit.to_static(
            model,
            input_spec=get_input_spec(model_path),
            full_graph=True,
            backend=None,
        )
        static_model.eval()
        program = static_model.forward.concrete_program.main_program

        program_analyzer = ProgramAnalyzer()
        program_analyzer(program)
        program_analyzer.summary()
        return program_analyzer
    except Exception as e:
        # Surface the failure reason instead of swallowing it silently;
        # callers still get None and treat the sample as incomplete.
        print(f"Failed with to_static: {e}")
        return None
def collect_model_stats(model_path, log_prompt):
    """Collect and print statistics for the model sample at *model_path*.

    Loads the ``GraphModule`` class from ``model.py``, converts it to a
    static program, and prints one ``[ModelStats.<key>]`` line per statistic,
    each prefixed with *log_prompt* so downstream tooling can filter the
    output.  When the static conversion fails, all graph-derived stats stay
    at their empty/zero defaults and ``is_complete`` is reported as False
    (the analyzer result must not be dereferenced when it is None).
    """
    file_path = os.path.join(model_path, "model.py")
    model_class = load_class_from_file(file_path, "GraphModule")
    model = model_class()
    num_outputs = get_number_of_returns(file_path, "GraphModule", "forward")

    model_size = 0
    input_dtypes = {}
    param_dtypes = {}
    ops_count_dict = {}
    op_dtypes = {}

    program_analyzer = collect_op_stats(model, model_path)
    if program_analyzer is not None:
        # Per-op counts, plus dtype histogram over all compute ops.
        for op_name, stat in sorted(program_analyzer.op_stats.items()):
            ops_count_dict[op_name] = stat.count
            for dtype_str, num in stat.op_dtypes.items():
                if dtype_str is not None and dtype_str != "None":
                    op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num

        inputs_params = utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        inputs = inputs_params["input_info"]

        # Split the graph's feed variables into parameters vs. real inputs;
        # parameter element counts accumulate into the total model size.
        for name, value in program_analyzer.input_dict.items():
            dtype_str = value["dtype"]
            if name in params.keys():
                param_numel = math.prod(value["shape"])
                model_size += param_numel
                param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
            elif name in inputs.keys():
                input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1

    model_size_in_billion = model_size / 1e9
    num_params = sum(param_dtypes.values())
    num_inputs = sum(input_dtypes.values())
    num_ops = sum(ops_count_dict.values())
    source, heuristic_tag = read_graph_source_and_tag(model_path)
    method = "to_static"
    is_complete = (
        program_analyzer.is_complete if program_analyzer is not None else False
    )

    def dict_to_string(d):
        # "k1:v1 k2:v2 ..." -- the stable log format parsed downstream.
        kv_list = [f"{k}:{v}" for k, v in d.items()]
        return " ".join(kv_list)

    def print_with_log_prompt(key, value):
        print(
            f"{log_prompt} [ModelStats.{key}] model_path:{model_path} {value}",
            flush=True,
        )

    print_with_log_prompt("num_inputs", num_inputs)
    print_with_log_prompt("num_params", num_params)
    print_with_log_prompt("num_outputs", num_outputs)
    print_with_log_prompt("num_ops", num_ops)
    print_with_log_prompt("model_size", f"{model_size_in_billion}B")
    print_with_log_prompt("input_dtypes", dict_to_string(input_dtypes))
    print_with_log_prompt("param_dtypes", dict_to_string(param_dtypes))
    print_with_log_prompt("op_dtypes", dict_to_string(op_dtypes))
    print_with_log_prompt("ops", dict_to_string(ops_count_dict))
    print_with_log_prompt("source", source)
    print_with_log_prompt("heuristic_tag", heuristic_tag)
    print_with_log_prompt("method", method)
    print_with_log_prompt("is_complete", is_complete)
def main(args):
    """Entry point: collect stats for one model dir, or sweep a samples tree.

    With ``--model-path`` set, stats are collected in-process for that single
    sample.  Otherwise every sample directory under the samples tree is
    processed in a subprocess (so one crashing or hanging model cannot kill
    the sweep); with ``--previous-collect-result-path`` given, only samples
    that previously reported ``is_complete:False`` are retried.

    Fixes: ``graph_net.paddle.samples_util`` was referenced without any
    import of ``graph_net`` (NameError); the subprocess launched the *torch*
    collector instead of the paddle one; a subprocess timeout raised an
    uncaught ``TimeoutExpired`` and aborted the whole sweep.
    """
    if args.model_path is not None:
        assert os.path.isdir(args.model_path)
        assert is_single_model_dir(args.model_path)
        timestamp_sec = datetime.now().timestamp()
        dt = datetime.fromtimestamp(timestamp_sec)
        formatted_dt = dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        print(f"[{formatted_dt}] Collect information for {args.model_path}")
        collect_model_stats(args.model_path, args.log_prompt)
    else:
        # Imported lazily: only the sweep path needs the samples helper.
        from graph_net.paddle import samples_util

        graph_net_samples_path = (
            samples_util.get_default_samples_directory()
            if args.graph_net_samples_path is None
            else args.graph_net_samples_path
        )

        previous_failed_model_pathes = []
        if args.previous_collect_result_path is not None:
            with open(args.previous_collect_result_path, "r") as f:
                for line in f:
                    if "[ModelStats]" in line:
                        fields = line.strip().split()
                        model_path = fields[2].split(":")[-1]
                        is_complete = fields[-1].split(":")[-1]
                        if is_complete == "False":
                            previous_failed_model_pathes.append(model_path)

        i = 0
        for root, dirs, files in os.walk(graph_net_samples_path):
            if is_single_model_dir(root) and (
                args.previous_collect_result_path is None
                or root in previous_failed_model_pathes
            ):
                print(f"[{i}] Collect information for {root}")
                cmd = [
                    sys.executable,  # the running interpreter, not bare "python"
                    "-m",
                    # Must target the paddle collector, not the torch one.
                    "graph_net.paddle.collect_stats",
                    f"--device={args.device}",
                    f"--model-path={root}",
                    f"--log-prompt={args.log_prompt}",
                ]
                try:
                    result = subprocess.run(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE,
                        text=True,
                        timeout=600,
                    )
                except subprocess.TimeoutExpired:
                    # A hung model must not abort the whole sweep.
                    print(f"Timeout while collecting stats for {root}")
                    i += 1
                    continue
                print(result.stdout)
                if result.returncode != 0:
                    print(result.stderr)
                i += 1
if __name__ == "__main__":
    # Command-line interface: all options are optional strings, so they are
    # registered table-driven as (flag, default, help) triples.
    parser = argparse.ArgumentParser(
        description="Collect stats for computation graph samples. return 0 if success"
    )
    _options = [
        (
            "--device",
            "cuda",
            "Device for testing the compiler (e.g., 'cpu' or 'cuda')",
        ),
        (
            "--model-path",
            None,
            "Computation graph sample directory. e.g '../../paddle_samples/PaddleX/ResNet18'",
        ),
        (
            "--graph-net-samples-path",
            None,
            "GraphNet samples directory. e.g '../../paddle_samples'",
        ),
        (
            "--previous-collect-result-path",
            None,
            "Previous collect result path, use to recollect the failed cases",
        ),
        (
            "--log-prompt",
            "graph-net-collect-stats-log",
            "Log prompt for stats log filtering.",
        ),
    ]
    for flag, default_value, help_text in _options:
        parser.add_argument(
            flag,
            type=str,
            required=False,
            default=default_value,
            help=help_text,
        )
    args = parser.parse_args()
    main(args=args)

0 commit comments

Comments
 (0)