File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change 1313import logging
1414
1515import onnx
16- from onnx import helper
16+ from onnx import helper , shape_inference
1717
1818from tf2onnx .graph import GraphUtil
1919from tf2onnx import logging , optimizer , constants
@@ -46,6 +46,12 @@ def load_graph(fname, target):
4646 return g , model_proto
4747
4848
49+ def model_shape_inference (onnx_model_proto ):
50+ inferred_model = shape_inference .infer_shapes (onnx_model_proto )
51+ onnx .checker .check_model (inferred_model )
52+ return inferred_model
53+
54+
4955def main ():
5056 args = get_args ()
5157
@@ -64,10 +70,12 @@ def main():
6470
6571 model_proto = helper .make_model (onnx_graph , ** kwargs )
6672
73+ model_proto_inferred = model_shape_inference (model_proto )
74+
6775 # write onnx graph
6876 if args .output :
6977 with open (args .output , "wb" ) as f :
70- f .write (model_proto .SerializeToString ())
78+ f .write (model_proto_inferred .SerializeToString ())
7179
7280
7381if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments