Skip to content

Commit 39b1806

Browse files
committed
Refine some code and fix support for VectorType.
1 parent 6d9b156 commit 39b1806

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

graph_net/paddle/collect_stats.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import os
3+
import re
34
import sys
45
import ast
56
import math
@@ -118,6 +119,24 @@ def update_op_stats(self, op_name, op_dtype):
118119
)
119120
self.op_stats[op_name].count += 1
120121

122+
def parse_pir_value_dtypes(self, type_str):
    """Extract the element dtype of every ``tensor<...>`` in a PIR type string.

    Args:
        type_str: Printed form of a PIR value type, e.g.
            ``"vec[tensor<1x18x13x9xf32>,tensor<1x9x13x9xf32>]"``.

    Returns:
        A list with one full dtype name (e.g. ``"float32"``) per
        ``tensor<...>`` occurrence, in order of appearance. The list is
        empty when ``type_str`` contains no tensor type.

    Raises:
        KeyError: If a tensor uses a short dtype form that is missing from
            the lookup table below.
    """
    # NOTE(review): the short forms beyond f32/f16/bf16/i64 are presumed to
    # match what the PIR IR printer emits -- confirm against the Paddle
    # version in use.
    short_form2dtype = {
        "f32": "float32",
        "f16": "float16",
        "bf16": "bfloat16",
        "f64": "float64",
        "i8": "int8",
        "i16": "int16",
        "i32": "int32",
        "i64": "int64",
        "u8": "uint8",
        "b": "bool",
        "c64": "complex64",
        "c128": "complex128",
    }
    dtype_strs = []
    for tensor_body in re.findall(r"tensor<([^>]+)>", type_str):
        # The body looks like "1x18x13x9xf32": dims joined by "x" with the
        # dtype as the last field. A scalar tensor ("f32") has no "x", in
        # which case rpartition returns the whole body as the last field.
        short_form = tensor_body.rpartition("x")[2].lower()
        if short_form not in short_form2dtype:
            raise KeyError(
                f"Unsupported dtype short form {short_form!r} in {type_str!r}"
            )
        dtype_strs.append(short_form2dtype[short_form])
    return dtype_strs
139+
121140
def __call__(self, program):
122141
assert isinstance(program, paddle.base.libpaddle.pir.Program)
123142

@@ -129,22 +148,38 @@ def __call__(self, program):
129148
op_name = None
130149
op_dtype = None
131150
if op.name() == "pd_op.data":
151+
op_name = "data"
132152
op_attrs = op.attrs()
133153
op_dtype = op_attrs["dtype"]
134154
self.input_dict[op_attrs["name"]] = {
135155
"dtype": str(op_dtype).replace("paddle.", ""),
136156
"shape": op_attrs["shape"],
137157
}
138-
elif not op.name().startswith("builtin."):
158+
elif op.name().startswith("pd_op."):
139159
self.num_ops += 1
140160
op_name = op.name().replace("pd_op.", "")
141-
if len(op.results()) > 0:
142-
op_dtype = op.results()[0].dtype
143-
144-
if op_name is not None:
145-
self.update_op_stats(op_name, op_dtype)
146-
elif op_dtype is None:
147-
self.num_ops_misses_dtypes += 1
161+
try:
162+
if len(op.results()) > 0:
163+
out = op.results()[0]
164+
if out.is_dense_tensor_type():
165+
op_dtype = out.dtype
166+
else:
167+
# Result is a paddle.base.libpaddle.pir.VectorType: its dtype cannot be determined precisely, so for known ops the first element's dtype is used.
168+
if op_name in ["split", "split_with_num", "meshgrid"]:
169+
op_dtype = self.parse_pir_value_dtypes(
170+
str(out.type())
171+
)[0]
172+
else:
173+
assert False, f"Unsupport op: {op}"
174+
except Exception:
175+
if self.num_ops_misses_dtypes == 0:
176+
print(f"dtype inference failed for {op_name}")
177+
if op_dtype is not None:
178+
self.update_op_stats(op_name, op_dtype)
179+
else:
180+
self.num_ops_misses_dtypes += 1
181+
elif not op.name().startswith("builtin."):
182+
assert False, f"Unrecognized op: {op}"
148183

149184
if self.num_ops_misses_dtypes > 0:
150185
self.is_complete = False
@@ -281,7 +316,7 @@ def main(args):
281316
cmd = [
282317
"python",
283318
"-m",
284-
"graph_net.torch.collect_stats",
319+
"graph_net.paddle.collect_stats",
285320
f"--device={args.device}",
286321
f"--model-path={root}",
287322
f"--log-prompt={args.log_prompt}",

graph_net/paddle/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def load_converted_list_from_text(file_path):
122122
return [*weight_info, *input_info]
123123

124124

125-
def ConvertToValidNumber(data_type, value):
125+
def convert_to_valid_number(data_type, value):
126126
if value is not None and data_type in [
127127
paddle.float32,
128128
paddle.float16,
@@ -160,10 +160,10 @@ def convert_meta_classes_to_tensors(file_path):
160160
"shape": attrs.get("shape", []),
161161
"dtype": data_type,
162162
"device": attrs.get("device", "gpu"),
163-
"mean": ConvertToValidNumber(data_type, attrs.get("mean", None)),
164-
"std": ConvertToValidNumber(data_type, attrs.get("std", None)),
165-
"min_val": ConvertToValidNumber(data_type, attrs.get("min_val", 0)),
166-
"max_val": ConvertToValidNumber(data_type, attrs.get("max_val", 2)),
163+
"mean": convert_to_valid_number(data_type, attrs.get("mean", None)),
164+
"std": convert_to_valid_number(data_type, attrs.get("std", None)),
165+
"min_val": convert_to_valid_number(data_type, attrs.get("min_val", 0)),
166+
"max_val": convert_to_valid_number(data_type, attrs.get("max_val", 2)),
167167
},
168168
"data": data_value,
169169
"name": attrs.get("name"),

0 commit comments

Comments
 (0)