Skip to content

Commit 5e0d3e8

Browse files
committed
Changed the stream of the Python runtime to the default stream
1 parent 0ca78fd commit 5e0d3e8

File tree

1 file changed

+11
-25
lines changed

1 file changed

+11
-25
lines changed

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ def __init__(
172172
self._input_buffers: List[torch.Tensor] = []
173173
self._output_buffers: List[torch.Tensor] = []
174174
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
175-
self._caller_stream: Optional[torch.cuda.Stream] = None
176-
self._engine_stream: Optional[torch.cuda.Stream] = None
175+
self._engine_stream: torch.cuda.Stream = torch.cuda.current_stream()
177176
self.output_tensors: Optional[List[torch.Tensor]] = None
178177
self.sync_stream = True
179178

@@ -288,13 +287,7 @@ def setup_engine(self) -> None:
288287
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
289288
# Stream handling: if the caller stream is the pytorch default stream, create a new engine stream
290289
# otherwise, use the caller stream and disable stream synchronization
291-
self._caller_stream = torch.cuda.current_stream()
292-
if self._caller_stream == torch.cuda.default_stream():
293-
self._engine_stream = torch.cuda.Stream()
294-
self.sync_stream = True
295-
else:
296-
self._engine_stream = self._caller_stream
297-
self.sync_stream = False
290+
self._engine_stream = torch.cuda.current_stream()
298291

299292
self.initialized = True
300293
runtime = trt.Runtime(TRT_LOGGER)
@@ -561,9 +554,6 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
561554
else nullcontext()
562555
):
563556

564-
if self.sync_stream:
565-
self._engine_stream.wait_stream(self._caller_stream)
566-
567557
if self.cudagraphs_enabled:
568558
if need_cudagraphs_record:
569559
self.cudagraph = torch.cuda.CUDAGraph()
@@ -593,10 +583,16 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
593583
self.cudagraph.replay() # type: ignore
594584

595585
else:
596-
self.context.execute_async_v3(self._engine_stream.cuda_stream)
586+
import warnings
597587

598-
if self.sync_stream:
599-
self._caller_stream.wait_stream(self._engine_stream)
588+
with warnings.catch_warnings():
589+
try:
590+
self.context.execute_async_v3(
591+
self._engine_stream.cuda_stream
592+
)
593+
except Warning as e:
594+
breakpoint()
595+
print("warning ignored")
600596

601597
if self.use_pre_allocated_outputs:
602598
self.pre_allocated_outputs = self.create_output_tensors()
@@ -651,22 +647,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
651647
if self.profiling_enabled
652648
else nullcontext()
653649
):
654-
self._caller_stream = torch.cuda.current_stream()
655-
if (
656-
self._engine_stream == torch.cuda.default_stream()
657-
or self._engine_stream is None
658-
):
659-
self._engine_stream = torch.cuda.Stream()
660-
661-
self._engine_stream.wait_stream(self._caller_stream)
662650

663651
with torch.cuda.stream(self._engine_stream):
664652
self.context.execute_async_v3(
665653
self._engine_stream.cuda_stream
666654
) # The OutputAllocator is called by execute_async_v3()
667655

668-
self._caller_stream.wait_stream(self._engine_stream)
669-
670656
with (
671657
torch.autograd.profiler.record_function(
672658
"PythonTorchTensorRTModule:ProcessOutputs"

0 commit comments

Comments (0)