@@ -93,6 +93,16 @@ const char *SAM::CreateSession(SEG::DL_INIT_PARAM &iParams) {
93
93
94
94
auto input_shape =
95
95
_session->GetInputTypeInfo (0 ).GetTensorTypeAndShapeInfo ().GetShape ();
96
+ // Optional shape check when model has fixed dims (not -1)
97
+ if (input_shape.size () >= 4 && input_shape[2 ] > 0 && input_shape[3 ] > 0 ) {
98
+ const int64_t expectH = _imgSize.at (1 );
99
+ const int64_t expectW = _imgSize.at (0 );
100
+ if (input_shape[2 ] != expectH || input_shape[3 ] != expectW) {
101
+ std::cerr << " [SAM]: Model input (H,W)=(" << input_shape[2 ] << " ," << input_shape[3 ]
102
+ << " ) mismatches configured imgSize (W,H)=(" << _imgSize[0 ] << " ," << _imgSize[1 ] << " )."
103
+ << std::endl;
104
+ }
105
+ }
96
106
auto output_shape =
97
107
_session->GetOutputTypeInfo (0 ).GetTensorTypeAndShapeInfo ().GetShape ();
98
108
auto output_type = _session->GetOutputTypeInfo (0 )
@@ -127,9 +137,9 @@ const char *SAM::RunSession(const cv::Mat &iImg,
127
137
utilities.BlobFromImage (processedImg, blob);
128
138
std::vector<int64_t > inputNodeDims;
129
139
if (_modelType == SEG::SAM_SEGMENT_ENCODER) {
130
- inputNodeDims = {1 , 3 , _imgSize.at (0 ), _imgSize.at (1 )};
140
+ // NCHW: H = imgSize[1], W = imgSize[0]
141
+ inputNodeDims = {1 , 3 , _imgSize.at (1 ), _imgSize.at (0 )};
131
142
} else if (_modelType == SEG::SAM_SEGMENT_DECODER) {
132
- // Input size or SAM decoder model is 256x64x64 for the decoder
133
143
inputNodeDims = {1 , 256 , 64 , 64 };
134
144
}
135
145
TensorProcess (starttime_1, iImg, blob, inputNodeDims, _modelType, oResult,
@@ -329,8 +339,9 @@ char *SAM::WarmUpSession(SEG::MODEL_TYPE _modelType) {
329
339
330
340
float *blob = new float [iImg.total () * 3 ];
331
341
utilities.BlobFromImage (processedImg, blob);
332
- std::vector<int64_t > SAM_input_node_dims = {1 , 3 , _imgSize.at (0 ),
333
- _imgSize.at (1 )};
342
+
343
+ // NCHW: H = imgSize[1], W = imgSize[0]
344
+ std::vector<int64_t > SAM_input_node_dims = {1 , 3 , _imgSize.at (1 ), _imgSize.at (0 )};
334
345
switch (_modelType) {
335
346
case SEG::SAM_SEGMENT_ENCODER: {
336
347
Ort::Value input_tensor = Ort::Value::CreateTensor<float >(
0 commit comments