@@ -96,7 +96,8 @@ void setup_input_tensors(
     std::vector<at::Tensor> inputs,
     c10::intrusive_ptr<TRTEngine> compiled_engine,
     bool cudagraphs_enabled,
-    bool need_cudagraphs_record) {
+    bool need_cudagraphs_record,
+    bool shape_changed) {
   // this is a buffer to store shape tensor input addresses throughout the runtime scope
   std::list<std::vector<int64_t>> inputShapeTensorValues;
   std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);
@@ -117,7 +118,7 @@ void setup_input_tensors(
     auto shape = core::util::toVec(dims);
     LOG_DEBUG("Input Name: " << name << " Shape: " << shape);
 
-    if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
+    if (compiled_engine->isShapeInferenceIO[name]) {
       // Shape tensor inputs are casted to int64 explicitly.
       // Refer to
       // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
@@ -145,10 +146,10 @@ void setup_input_tensors(
         // Create a new persistent input buffer
         compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
       }
-
-      TORCHTRT_CHECK(
-          compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
-
+      if (shape_changed) {
+        TORCHTRT_CHECK(
+            compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
+      }
       if (cudagraphs_enabled) {
         // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
         compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
@@ -217,7 +218,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
       compiled_engine->cudagraph.reset();
     }
 
-    std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
+    std::vector<at::Tensor> outputs;
 
     // Intialize inputs and outputs to be available throughout the succeeding scopes
     { // Input Setup
@@ -226,10 +227,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
         input_profiler_guard =
             std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
       }
-
-      setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
+      setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, shape_changed);
       // Check if input shapes can be inferred.
-      int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
+      int32_t const io_size{compiled_engine->io_size};
       std::vector<char const*> names(io_size);
       int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data());
       TORCHTRT_CHECK(
@@ -240,6 +240,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
 
     { // Output Setup
+      bool new_outputs = false;
       std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
       if (compiled_engine->profile_execution) {
         output_profiler_guard =
@@ -248,64 +249,59 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
       if (can_use_pre_allocated_outputs) {
         outputs = compiled_engine->pre_allocated_outputs;
       } else {
-        outputs = create_output_tensors(compiled_engine);
+        if (compiled_engine->allocated_outputs.size() == 0 or compiled_engine->unowned_output_tensor or shape_changed) {
+          compiled_engine->allocated_outputs = create_output_tensors(compiled_engine);
+          new_outputs = true;
+        }
+        outputs = compiled_engine->allocated_outputs;
       }
 
-      for (auto output_indices : compiled_engine->out_binding_map) {
-        auto pyt_idx = output_indices.second;
-        std::string name = compiled_engine->out_binding_names[pyt_idx];
-        if (need_cudagraphs_record) {
-          // If we are recording the cuda graph then we need to update the persistent output buffer
-          compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
-        }
+      if (new_outputs) {
+        for (auto output_indices : compiled_engine->out_binding_map) {
+          auto pyt_idx = output_indices.second;
+          std::string name = compiled_engine->out_binding_names[pyt_idx];
+          if (need_cudagraphs_record) {
+            // If we are recording the cuda graph then we need to update the persistent output buffer
+            compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
+          }
 
-        if (cudagraphs_enabled) {
-          TORCHTRT_CHECK(
-              compiled_engine->exec_ctx->setTensorAddress(
-                  name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
-              "Error while setting the output tensor address");
-        } else {
-          TORCHTRT_CHECK(
-              compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
-              "Error while setting the output tensor address");
+          if (cudagraphs_enabled) {
+            TORCHTRT_CHECK(
+                compiled_engine->exec_ctx->setTensorAddress(
+                    name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
+                "Error while setting the output tensor address");
+          } else {
+            TORCHTRT_CHECK(
+                compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
+                "Error while setting the output tensor address");
+          }
         }
       }
     }
 
     auto current_device_id = -1;
     if (inputs.size() > 0) {
       current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
-    } else if (outputs.size() > 0) {
-      current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
-    }
-
-    compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
-    if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
-      // Create a new stream if the engine stream is the default stream
-      compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
+      if (current_device_id != compiled_engine->current_device_id) {
+        compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
+      }
     }
 
     { // Engine Execution (execute on engine stream)
-      c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);
 
       std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
       if (compiled_engine->profile_execution) {
         enqueue_profiler_guard =
             std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
       }
 
-      // Block engine stream until results are available on caller stream
-      at::cuda::CUDAEvent caller_exec_complete;
-      caller_exec_complete.record(compiled_engine->caller_stream);
-      caller_exec_complete.block(compiled_engine->engine_stream);
-
       if (!cudagraphs_enabled) {
         // Direct execution uses the caller buffers directly
-        compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
+        compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream);
       } else {
         if (need_cudagraphs_record) {
           // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
-          c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
+          c10::cuda::CUDAStream recording_stream = compiled_engine->stream;
           compiled_engine->cudagraph.capture_begin();
           compiled_engine->exec_ctx->enqueueV3(recording_stream);
           compiled_engine->cudagraph.capture_end();
@@ -325,11 +321,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
       compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
     }
 
-    // Block caller stream until engine execution is complete
-    at::cuda::CUDAEvent trt_exec_complete;
-    trt_exec_complete.record(compiled_engine->engine_stream);
-    trt_exec_complete.block(compiled_engine->caller_stream);
-
     if (cudagraphs_enabled) {
       // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
       for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
@@ -354,7 +345,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
             std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
       }
 
-      setup_input_tensors(inputs, compiled_engine, false, false);
+      setup_input_tensors(inputs, compiled_engine, false, false, true);
       // Check if input shapes can be inferred.
       int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
       std::vector<char const*> names(io_size);
@@ -378,40 +369,24 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     auto current_device_id = -1;
     if (inputs.size() > 0) {
       current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
-    } else {
-      current_device_id = at::cuda::current_device();
-    }
-
-    compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
-    if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
-      // Create a new stream if the engine stream is the default stream
-      compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
+      if (current_device_id != compiled_engine->current_device_id) {
+        compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
+      }
     }
 
     { // Engine Execution (execute on engine stream)
-      c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);
 
       std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
       if (compiled_engine->profile_execution) {
         enqueue_profiler_guard =
             std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
       }
 
-      // Block engine stream until results are available on caller stream
-      at::cuda::CUDAEvent caller_exec_complete;
-      caller_exec_complete.record(compiled_engine->caller_stream);
-      caller_exec_complete.block(compiled_engine->engine_stream);
-
       // Direct execution uses the caller buffers directly
-      compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
+      compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream);
 
     } // End engine exeuction (resets to caller stream)
 
-    // Block caller stream until engine execution is complete
-    at::cuda::CUDAEvent trt_exec_complete;
-    trt_exec_complete.record(compiled_engine->engine_stream);
-    trt_exec_complete.block(compiled_engine->caller_stream);
-
     std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
     if (compiled_engine->profile_execution) {
       output_profiler_guard =