1
1
import argparse
2
2
import os
3
+ import re
3
4
import sys
4
5
import ast
5
6
import math
@@ -118,6 +119,24 @@ def update_op_stats(self, op_name, op_dtype):
118
119
)
119
120
self .op_stats [op_name ].count += 1
120
121
122
+ def parse_pir_value_dtypes (self , type_str ):
123
+ short_form2dtype = {
124
+ "f32" : "float32" ,
125
+ "f16" : "float16" ,
126
+ "bf16" : "bfloat16" ,
127
+ "i64" : "int64" ,
128
+ }
129
+ # type_str: "vec[tensor<1x18x13x9xf32>,tensor<1x9x13x9xf32>]"
130
+ matches = re .findall (r"tensor<([^>]+)>" , type_str )
131
+ dtype_strs = []
132
+ for s in matches :
133
+ parts = s .split ("x" )
134
+ assert len (parts ) > 0
135
+
136
+ dtype = parts [- 1 ].lower ()
137
+ dtype_strs .append (short_form2dtype [dtype ])
138
+ return dtype_strs
139
+
121
140
def __call__ (self , program ):
122
141
assert isinstance (program , paddle .base .libpaddle .pir .Program )
123
142
@@ -129,22 +148,38 @@ def __call__(self, program):
129
148
op_name = None
130
149
op_dtype = None
131
150
if op .name () == "pd_op.data" :
151
+ op_name = "data"
132
152
op_attrs = op .attrs ()
133
153
op_dtype = op_attrs ["dtype" ]
134
154
self .input_dict [op_attrs ["name" ]] = {
135
155
"dtype" : str (op_dtype ).replace ("paddle." , "" ),
136
156
"shape" : op_attrs ["shape" ],
137
157
}
138
- elif not op .name ().startswith ("builtin ." ):
158
+ elif op .name ().startswith ("pd_op ." ):
139
159
self .num_ops += 1
140
160
op_name = op .name ().replace ("pd_op." , "" )
141
- if len (op .results ()) > 0 :
142
- op_dtype = op .results ()[0 ].dtype
143
-
144
- if op_name is not None :
145
- self .update_op_stats (op_name , op_dtype )
146
- elif op_dtype is None :
147
- self .num_ops_misses_dtypes += 1
161
+ try :
162
+ if len (op .results ()) > 0 :
163
+ out = op .results ()[0 ]
164
+ if out .is_dense_tensor_type ():
165
+ op_dtype = out .dtype
166
+ else :
167
+ # for paddle.base.libpaddle.pir.VectorType, but cannot be accurately determined
168
+ if op_name in ["split" , "split_with_num" , "meshgrid" ]:
169
+ op_dtype = self .parse_pir_value_dtypes (
170
+ str (out .type ())
171
+ )[0 ]
172
+ else :
173
+ assert False , f"Unsupport op: { op } "
174
+ except Exception :
175
+ if self .num_ops_misses_dtypes == 0 :
176
+ print (f"dtype inference failed for { op_name } " )
177
+ if op_dtype is not None :
178
+ self .update_op_stats (op_name , op_dtype )
179
+ else :
180
+ self .num_ops_misses_dtypes += 1
181
+ elif not op .name ().startswith ("builtin." ):
182
+ assert False , f"Unrecognized op: { op } "
148
183
149
184
if self .num_ops_misses_dtypes > 0 :
150
185
self .is_complete = False
@@ -281,7 +316,7 @@ def main(args):
281
316
cmd = [
282
317
"python" ,
283
318
"-m" ,
284
- "graph_net.torch .collect_stats" ,
319
+ "graph_net.paddle .collect_stats" ,
285
320
f"--device={ args .device } " ,
286
321
f"--model-path={ root } " ,
287
322
f"--log-prompt={ args .log_prompt } " ,
0 commit comments