2020 */
2121package io .bioimage .modelrunner .onnx ;
2222
23+ import io .bioimage .modelrunner .apposed .appose .Types ;
2324import io .bioimage .modelrunner .engine .DeepLearningEngineInterface ;
2425import io .bioimage .modelrunner .exceptions .LoadModelException ;
2526import io .bioimage .modelrunner .exceptions .RunModelException ;
2627import io .bioimage .modelrunner .onnx .tensor .ImgLib2Builder ;
2728import io .bioimage .modelrunner .onnx .tensor .TensorBuilder ;
2829import io .bioimage .modelrunner .tensor .Tensor ;
2930import net .imglib2 .RandomAccessibleInterval ;
30- import net .imglib2 .img . array . ArrayImgs ;
31- import net .imglib2 .type .numeric .real . FloatType ;
31+ import net .imglib2 .type . NativeType ;
32+ import net .imglib2 .type .numeric .RealType ;
3233
33- import java .io .File ;
3434import java .util .ArrayList ;
3535import java .util .Iterator ;
3636import java .util .LinkedHashMap ;
@@ -85,6 +85,7 @@ public OnnxInterface()
8585 {
8686 }
8787
88+ /**
8889 public static void main(String args[]) throws LoadModelException, RunModelException {
8990 String folderName = "/home/carlos/git/deep-icy/models/NucleiSegmentationBoundaryModel_27112023_190556";
9091 String source = folderName + File.separator + "weights.onnx";
@@ -104,6 +105,7 @@ public static void main(String args[]) throws LoadModelException, RunModelExcept
104105 oi.run(inps, outs);
105106 System.out.println(false);
106107 }
108+ */
107109
108110 /**
109111 * {@inheritDoc}
@@ -132,10 +134,12 @@ public void loadModel(String modelFolder, String modelSource) throws LoadModelEx
132134 *
133135 * Run a Onnx model on the data provided by the {@link Tensor} input list
134136 * and modifies the output list with the results obtained
137+ * @throws RunModelException
135138 *
136139 */
137140 @ Override
138- public void run (List <Tensor <?>> inputTensors , List <Tensor <?>> outputTensors ) throws RunModelException {
141+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
142+ void run (List <Tensor <T >> inputTensors , List <Tensor <R >> outputTensors ) throws RunModelException {
139143 Result output ;
140144 LinkedHashMap <String , OnnxTensor > inputMap = new LinkedHashMap <String , OnnxTensor >();
141145 Iterator <String > inputNames = session .getInputNames ().iterator ();
@@ -160,10 +164,47 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) thr
160164 for (OnnxTensor tt : inputMap .values ()) {
161165 tt .close ();
162166 }
163- for (Object tt : output ) {
164- tt = null ;
167+ output .close ();
168+ }
169+
170+ @ Override
171+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >> List <RandomAccessibleInterval <R >> inference (
172+ List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
173+ Result output ;
174+ LinkedHashMap <String , OnnxTensor > inputMap = new LinkedHashMap <String , OnnxTensor >();
175+ Iterator <String > inputNames = session .getInputNames ().iterator ();
176+ try {
177+ for (RandomAccessibleInterval <T > tt : inputs ) {
178+ OnnxTensor inT = TensorBuilder .build (tt , env );
179+ inputMap .put (inputNames .next (), inT );
180+ }
181+ output = session .run (inputMap );
182+ } catch (OrtException ex ) {
183+ for (OnnxTensor tt : inputMap .values ()) {
184+ tt .close ();
185+ }
186+ throw new RunModelException ("Error trying to run an Onnx model."
187+ + System .lineSeparator () + Types .stackTrace (ex ));
188+ }
189+ for (OnnxTensor tt : inputMap .values ()) {
190+ tt .close ();
191+ }
192+
193+ // Fill the agnostic output tensors list with data from the inference result
194+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
195+ for (int i = 0 ; i < output .size (); i ++) {
196+ try {
197+ rais .add (ImgLib2Builder .build (output .get (i ).getValue ()));
198+ output .get (i ).close ();
199+ } catch (IllegalArgumentException | OrtException e ) {
200+ for (int j = i ; j < output .size (); j ++)
201+ output .get (j ).close ();
202+ output .close ();
203+ throw new RunModelException ("Error converting tensor into RAI" + Types .stackTrace (e ));
204+ }
165205 }
166206 output .close ();
207+ return rais ;
167208 }
168209
169210 /**
@@ -179,17 +220,22 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) thr
179220 * @throws RunModelException If the number of tensors expected is not the same
180221 * as the number of Tensors outputed by the model
181222 */
182- public static void fillOutputTensors (Result onnxTensors , List <Tensor <?>> outputTensors ) throws RunModelException {
223+ public static <T extends RealType <T > & NativeType <T >>
224+ void fillOutputTensors (Result onnxTensors ,
225+ List <Tensor <T >> outputTensors ) throws RunModelException {
183226 if (onnxTensors .size () != outputTensors .size ())
184227 throw new RunModelException (onnxTensors .size (), outputTensors .size ());
185228 int cc = 0 ;
186229 for (Tensor tt : outputTensors ) {
187230 try {
188- tt .setData (ImgLib2Builder .build (onnxTensors .get (cc ++).getValue ()));
231+ tt .setData (ImgLib2Builder .build (onnxTensors .get (cc ).getValue ()));
232+ onnxTensors .get (cc ).close ();
233+ cc ++;
189234 } catch (IllegalArgumentException | OrtException e ) {
190- e .printStackTrace ();
191- throw new RunModelException ("Unable to recover value of output tensor: " + tt .getName ()
192- + System .lineSeparator () + e .getCause ().toString ());
235+ for (int j = cc ; j < onnxTensors .size (); j ++)
236+ onnxTensors .get (j ).close ();
237+ onnxTensors .close ();
238+ throw new RunModelException ("Error converting tensor '" + tt .getName () + "' into RAI" + Types .stackTrace (e ));
193239 }
194240 }
195241 }
0 commit comments