5
5
from onnx import helper , numpy_helper
6
6
from ultralytics import YOLO
7
7
8
+
8
9
def convert_pt_to_onnx (pt_model_path , onnx_model_path = None ):
9
10
if onnx_model_path is None :
10
11
onnx_model_path = pt_model_path .replace ('.pt' , '.onnx' )
@@ -33,10 +34,8 @@ def onnx_to_json(model_path, output_json_path):
33
34
34
35
layer_info = []
35
36
36
- # Extract input information from ONNX model
37
37
input_info = {}
38
38
for input in model .graph .input :
39
- # Skip initializers (they are weights, not actual inputs)
40
39
if input .name in initializers_dict :
41
40
continue
42
41
@@ -45,9 +44,8 @@ def onnx_to_json(model_path, output_json_path):
45
44
"shape" : [dim .dim_value for dim in input .type .tensor_type .shape .dim ],
46
45
"data_type" : input .type .tensor_type .elem_type
47
46
}
48
- break # Take the first actual input
47
+ break
49
48
50
- # Create input layer with proper information
51
49
input_layer = {
52
50
"index" : 0 ,
53
51
"name" : input_info .get ("name" , "input_1" ),
@@ -67,12 +65,11 @@ def onnx_to_json(model_path, output_json_path):
67
65
"name" : node .name .replace ('/' , '_' ),
68
66
"type" : node .op_type ,
69
67
"attributes" : {},
70
- "inputs" : [] # Add inputs information
68
+ "inputs" : []
71
69
}
72
70
73
- # Add input connections
74
71
for input_name in node .input :
75
- if input_name not in initializers_dict : # Only track layer connections, not weights
72
+ if input_name not in initializers_dict :
76
73
layer_data ["inputs" ].append (input_name .replace ('/' , '_' ))
77
74
78
75
for attr in node .attribute :
@@ -94,29 +91,44 @@ def onnx_to_json(model_path, output_json_path):
94
91
elif attr .name == "strides" :
95
92
layer_data ["strides" ] = attr_value
96
93
97
- node_init = []
98
- for input_name in node .input :
99
- if input_name in initializers_dict :
100
- node_init .append (initializers_dict [input_name ])
101
-
102
- if len (node_init ) == 1 :
103
- init = node_init [0 ]
104
- if len (init ["dims" ]) == 0 or (len (init ["dims" ]) == 1 and init ["dims" ][0 ] == 1 ):
105
- layer_data ["value" ] = init ["values" ] if len (init ["dims" ]) == 0 else init ["values" ][0 ]
106
- else :
107
- layer_data ["weights" ] = init ["values" ]
108
- elif len (node_init ) > 1 :
109
- weights = []
110
- for init in node_init [:- 1 ]:
111
- if len (init ["dims" ]) > 0 :
112
- weights .extend (init ["values" ]) if isinstance (init ["values" ][0 ], list ) else weights .append (
113
- init ["values" ])
114
-
115
- if weights :
116
- layer_data ["weights" ] = weights
117
-
118
- if len (node_init [- 1 ]["dims" ]) == 1 :
119
- layer_data ["bias" ] = node_init [- 1 ]["values" ]
94
+ if node .op_type == "BatchNormalization" :
95
+ bn_params = []
96
+ for input_name in node .input :
97
+ if input_name in initializers_dict :
98
+ bn_params .append (initializers_dict [input_name ])
99
+
100
+ if len (bn_params ) >= 4 :
101
+ layer_data ["scale" ] = bn_params [0 ]["values" ]
102
+ layer_data ["bias" ] = bn_params [1 ]["values" ]
103
+ layer_data ["mean" ] = bn_params [2 ]["values" ]
104
+ layer_data ["var" ] = bn_params [3 ]["values" ]
105
+
106
+ layer_data ["weights" ] = []
107
+
108
+ else :
109
+ node_init = []
110
+ for input_name in node .input :
111
+ if input_name in initializers_dict :
112
+ node_init .append (initializers_dict [input_name ])
113
+
114
+ if len (node_init ) == 1 :
115
+ init = node_init [0 ]
116
+ if len (init ["dims" ]) == 0 or (len (init ["dims" ]) == 1 and init ["dims" ][0 ] == 1 ):
117
+ layer_data ["value" ] = init ["values" ] if len (init ["dims" ]) == 0 else init ["values" ][0 ]
118
+ else :
119
+ layer_data ["weights" ] = init ["values" ]
120
+ elif len (node_init ) > 1 :
121
+ weights = []
122
+ for init in node_init [:- 1 ]:
123
+ if len (init ["dims" ]) > 0 :
124
+ weights .extend (init ["values" ]) if isinstance (init ["values" ][0 ], list ) else weights .append (
125
+ init ["values" ])
126
+
127
+ if weights :
128
+ layer_data ["weights" ] = weights
129
+
130
+ if len (node_init [- 1 ]["dims" ]) == 1 :
131
+ layer_data ["bias" ] = node_init [- 1 ]["values" ]
120
132
121
133
layer_info .append (layer_data )
122
134
@@ -144,7 +156,7 @@ def default(self, obj):
144
156
145
157
BASE_DIR = os .path .dirname (os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))))
146
158
147
- MODEL_PATH = os .path .join (BASE_DIR , 'docs\\ models' , 'GoogLeNet .onnx' )
148
- MODEL_DATA_PATH = os .path .join (BASE_DIR , 'docs\\ jsons' , 'googlenet_onnx_model .json' )
159
+ MODEL_PATH = os .path .join (BASE_DIR , 'docs\\ models' , 'densenet121_Opset16 .onnx' )
160
+ MODEL_DATA_PATH = os .path .join (BASE_DIR , 'docs\\ jsons' , 'densenet121_Opset16_onnx_model .json' )
149
161
150
162
onnx_to_json (MODEL_PATH , MODEL_DATA_PATH )
0 commit comments