@@ -172,8 +172,7 @@ def __init__(
172172        self ._input_buffers : List [torch .Tensor ] =  []
173173        self ._output_buffers : List [torch .Tensor ] =  []
174174        self .cudagraph : Optional [torch .cuda .CUDAGraph ] =  None 
175-         self ._caller_stream : Optional [torch .cuda .Stream ] =  None 
176-         self ._engine_stream : Optional [torch .cuda .Stream ] =  None 
175+         self ._engine_stream : torch .cuda .Stream  =  torch .cuda .current_stream ()
177176        self .output_tensors : Optional [List [torch .Tensor ]] =  None 
178177        self .sync_stream  =  True 
179178
@@ -288,13 +287,7 @@ def setup_engine(self) -> None:
288287        ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
289288        # Stream handling: if the caller stream is the pytorch default stream, create a new engine stream 
290289        # otherwise, use the caller stream and disable stream synchronization 
291-         self ._caller_stream  =  torch .cuda .current_stream ()
292-         if  self ._caller_stream  ==  torch .cuda .default_stream ():
293-             self ._engine_stream  =  torch .cuda .Stream ()
294-             self .sync_stream  =  True 
295-         else :
296-             self ._engine_stream  =  self ._caller_stream 
297-             self .sync_stream  =  False 
290+         self ._engine_stream  =  torch .cuda .current_stream ()
298291
299292        self .initialized  =  True 
300293        runtime  =  trt .Runtime (TRT_LOGGER )
@@ -561,9 +554,6 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
561554                else  nullcontext ()
562555            ):
563556
564-                 if  self .sync_stream :
565-                     self ._engine_stream .wait_stream (self ._caller_stream )
566- 
567557                if  self .cudagraphs_enabled :
568558                    if  need_cudagraphs_record :
569559                        self .cudagraph  =  torch .cuda .CUDAGraph ()
@@ -593,10 +583,16 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
593583                    self .cudagraph .replay ()  # type: ignore 
594584
595585                else :
596-                     self . context . execute_async_v3 ( self . _engine_stream . cuda_stream ) 
586+                     # NOTE(review): removed leftover debug code that was added here
587+                     # (breakpoint(), print, and `except Warning` inside
588+                     # warnings.catch_warnings() — warnings are not raised as
589+                     # exceptions unless promoted to errors, so that handler never fires).
590+                     self .context .execute_async_v3 (
591+                         self ._engine_stream .cuda_stream 
592+                     )
597593

598-                 if  self .sync_stream :
599-                     self ._caller_stream .wait_stream (self ._engine_stream )
600596
601597            if  self .use_pre_allocated_outputs :
602598                self .pre_allocated_outputs  =  self .create_output_tensors ()
@@ -651,22 +647,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
651647                if  self .profiling_enabled 
652648                else  nullcontext ()
653649            ):
654-                 self ._caller_stream  =  torch .cuda .current_stream ()
655-                 if  (
656-                     self ._engine_stream  ==  torch .cuda .default_stream ()
657-                     or  self ._engine_stream  is  None 
658-                 ):
659-                     self ._engine_stream  =  torch .cuda .Stream ()
660- 
661-                 self ._engine_stream .wait_stream (self ._caller_stream )
662650
663651                with  torch .cuda .stream (self ._engine_stream ):
664652                    self .context .execute_async_v3 (
665653                        self ._engine_stream .cuda_stream 
666654                    )  # The OutputAllocator is called by execute_async_v3() 
667655
668-                 self ._caller_stream .wait_stream (self ._engine_stream )
669- 
670656            with  (
671657                torch .autograd .profiler .record_function (
672658                    "PythonTorchTensorRTModule:ProcessOutputs" 
0 commit comments