Skip to content

Commit 7c89cfd

Browse files
committed
Refine some codes.
1 parent 6d9b156 commit 7c89cfd

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

graph_net/paddle/collect_stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def main(args):
281281
cmd = [
282282
"python",
283283
"-m",
284-
"graph_net.torch.collect_stats",
284+
"graph_net.paddle.collect_stats",
285285
f"--device={args.device}",
286286
f"--model-path={root}",
287287
f"--log-prompt={args.log_prompt}",

graph_net/paddle/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def load_converted_list_from_text(file_path):
122122
return [*weight_info, *input_info]
123123

124124

125-
def ConvertToValidNumber(data_type, value):
125+
def convert_to_valid_number(data_type, value):
126126
if value is not None and data_type in [
127127
paddle.float32,
128128
paddle.float16,
@@ -160,10 +160,10 @@ def convert_meta_classes_to_tensors(file_path):
160160
"shape": attrs.get("shape", []),
161161
"dtype": data_type,
162162
"device": attrs.get("device", "gpu"),
163-
"mean": ConvertToValidNumber(data_type, attrs.get("mean", None)),
164-
"std": ConvertToValidNumber(data_type, attrs.get("std", None)),
165-
"min_val": ConvertToValidNumber(data_type, attrs.get("min_val", 0)),
166-
"max_val": ConvertToValidNumber(data_type, attrs.get("max_val", 2)),
163+
"mean": convert_to_valid_number(data_type, attrs.get("mean", None)),
164+
"std": convert_to_valid_number(data_type, attrs.get("std", None)),
165+
"min_val": convert_to_valid_number(data_type, attrs.get("min_val", 0)),
166+
"max_val": convert_to_valid_number(data_type, attrs.get("max_val", 2)),
167167
},
168168
"data": data_value,
169169
"name": attrs.get("name"),

0 commit comments

Comments
 (0)