Skip to content

Commit 832a7f4

Browse files
committed
Add EP API Stream support
1 parent 5828e10 commit 832a7f4

10 files changed

+377
-14
lines changed

plugin_execution_providers/tensorrt/cuda_allocator.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ using DeviceId = int16_t;
1010
struct CUDAAllocator : OrtAllocator {
1111
CUDAAllocator(const OrtMemoryInfo* mem_info, DeviceId device_id) : mem_info_(mem_info), device_id_(device_id) {
1212
OrtAllocator::version = ORT_API_VERSION;
13-
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) {
14-
return static_cast<CUDAAllocator*>(this_)->Alloc(size);
15-
};
13+
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAAllocator*>(this_)->Alloc(size); };
1614
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAAllocator*>(this_)->Free(p); };
1715
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const CUDAAllocator*>(this_)->Info(); };
16+
OrtAllocator::Reserve = nullptr;
17+
OrtAllocator::GetStats = nullptr;
18+
OrtAllocator::AllocOnStream = nullptr; // Allocate memory, handling usage across different Streams. Not used for TRT EP.
1819
}
1920
// TODO: Handle destructor
2021
//~CUDAAllocator();
@@ -41,6 +42,9 @@ struct CUDAPinnedAllocator : OrtAllocator {
4142
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<CUDAPinnedAllocator*>(this_)->Alloc(size); };
4243
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<CUDAPinnedAllocator*>(this_)->Free(p); };
4344
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const CUDAPinnedAllocator*>(this_)->Info(); };
45+
OrtAllocator::Reserve = nullptr;
46+
OrtAllocator::GetStats = nullptr;
47+
OrtAllocator::AllocOnStream = nullptr;
4448
}
4549
// TODO: Handle destructor
4650
//~CUDAPinnedAllocator();

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "tensorrt_execution_provider.h"
1515
#include "cuda_allocator.h"
1616
#include "onnx_ctx_model_helper.h"
17+
#include "tensorrt_execution_provider_stream_support.h"
1718
#include "onnx/onnx_pb.h"
1819
#include "cuda/unary_elementwise_ops_impl.h"
1920
#include "ep_utils.h"
@@ -1960,6 +1961,30 @@ const char* ORT_API_CALL TensorrtExecutionProvider::GetNameImpl(const OrtEp* thi
19601961
return ep->name_.c_str();
19611962
}
19621963

1964+
OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr,
1965+
_In_ const OrtMemoryDevice* memory_device,
1966+
_Outptr_ OrtSyncStreamImpl** stream) noexcept {
1967+
// A per-session OrtSyncStreamImpl can be created here if the session options affect the implementation.
1968+
// Logging of any issues should use logger_ which is the session logger.
1969+
1970+
TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr);
1971+
1972+
// we only create streams for the default device memory.
1973+
if (auto mem_type = ep->factory_.ep_api.MemoryDevice_GetMemoryType(memory_device);
1974+
mem_type != OrtDeviceMemoryType_DEFAULT) {
1975+
std::string error = "Invalid OrtMemoryDevice. Expected OrtDeviceMemoryType_DEFAULT(0). Got ";
1976+
error += std::to_string(mem_type);
1977+
return ep->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, error.c_str());
1978+
}
1979+
1980+
auto device_id = ep->factory_.ep_api.MemoryDevice_GetDeviceId(memory_device);
1981+
1982+
auto sync_stream = std::make_unique<TrtSyncStreamImpl>(ep->factory_, ep, device_id, nullptr);
1983+
*stream = sync_stream.release();
1984+
1985+
return nullptr;
1986+
}
1987+
19631988
/**
19641989
* Refit the weight-stripped engine
19651990
*/
@@ -2070,6 +2095,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
20702095
GetCapability = GetCapabilityImpl;
20712096
Compile = CompileImpl;
20722097
ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl;
2098+
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;
20732099

20742100
// Initialize the execution provider.
20752101
auto status = ort_api.Logger_LogMessage(&logger_,
@@ -2158,7 +2184,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
21582184
force_timing_cache_match_ = info_.force_timing_cache;
21592185
detailed_build_log_ = info_.detailed_build_log;
21602186
dump_ep_context_model_ = info_.dump_ep_context_model;
2161-
dump_ep_context_model_ = true;
2187+
//dump_ep_context_model_ = true;
21622188
ep_context_file_path_ = info_.ep_context_file_path;
21632189
ep_context_embed_mode_ = info_.ep_context_embed_mode;
21642190
enable_engine_cache_for_ep_context_model();
@@ -2378,7 +2404,6 @@ void ORT_API_CALL TensorrtExecutionProvider::ReleaseNodeComputeInfosImpl(OrtEp*
23782404
}
23792405
}
23802406

2381-
23822407
//
23832408
// Implementation of TRTEpNodeComputeInfo
23842409
//
@@ -2487,7 +2512,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
24872512
cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);
24882513

24892514
//cudaStream_t stream;
2490-
cudaStreamCreate(&stream);
2515+
//cudaStreamCreate(&stream);
24912516

24922517
// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
24932518
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even
@@ -3053,6 +3078,9 @@ void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void*
30533078
// Do nothing for here.
30543079
}
30553080

3081+
//
3082+
// Implementation of TRTEpEpContextNodeComputeInfo
3083+
//
30563084
TRTEpEpContextNodeComputeInfo::TRTEpEpContextNodeComputeInfo(TensorrtExecutionProvider& ep) : ep(ep) {
30573085
ort_version_supported = ORT_API_VERSION;
30583086
CreateState = CreateStateImpl;

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename";
227227

228228
/// <summary>
229229
///
230-
/// Plugin TensorRT EP
230+
/// Plugin TensorRT EP implementing OrtEp.
231231
///
232232
/// </summary>
233233
struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
@@ -311,6 +311,8 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
311311
std::unordered_map<std::string, std::string> trt_node_name_with_precision_;
312312
std::unordered_map<std::string, std::unordered_map<std::string, float>> dynamic_range_map_;
313313
std::unordered_map<std::string, std::string> cache_suffix_;
314+
bool external_stream_ = false;
315+
cudaStream_t stream_ = nullptr;
314316

315317
private:
316318
static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept;
@@ -323,12 +325,11 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs {
323325
static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos,
324326
size_t num_node_compute_infos) noexcept;
325327

326-
OrtStatus* CreateEpContextNodes(gsl::span<const OrtNode*> fused_nodes,
327-
/*out*/ gsl::span<OrtNode*> ep_context_nodes);
328+
static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr,
329+
_In_ const OrtMemoryDevice* memory_device,
330+
_Outptr_ OrtSyncStreamImpl** stream) noexcept;
328331

329332
mutable TensorrtExecutionProviderInfo info_;
330-
bool external_stream_ = false;
331-
cudaStream_t stream_ = nullptr;
332333
int max_partition_iterations_ = 1000;
333334
size_t min_subgraph_size_ = 1;
334335
size_t max_workspace_size_ = 1 << 30; // 1GB
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "tensorrt_execution_provider_stream_support.h"
5+
#include "tensorrt_provider_factory.h"
6+
#include "tensorrt_execution_provider.h"
7+
8+
#include "cuda/cuda_common.h"
9+
#include "cuda/cuda_call.h"
10+
11+
//
12+
// TrtSyncStreamImpl implementation
13+
//
14+
15+
TrtSyncStreamImpl::TrtSyncStreamImpl(TensorrtExecutionProviderFactory& factory, const OrtEp* ep, uint32_t device_id, const OrtKeyValuePairs* /*stream_options*/)
16+
: ApiPtrs(factory), ep_{ep}, factory_{&factory} {
17+
ort_version_supported = ORT_API_VERSION;
18+
CreateNotification = CreateNotificationImpl;
19+
GetHandle = GetHandleImpl;
20+
Flush = FlushImpl;
21+
OnSessionRunEnd = OnSessionRunEndImpl;
22+
Release = ReleaseImpl;
23+
24+
const TensorrtExecutionProvider* trt_ep = static_cast<const TensorrtExecutionProvider*>(ep_);
25+
if (trt_ep->external_stream_) {
26+
stream_ = trt_ep->stream_;
27+
own_stream_ = false;
28+
} else {
29+
CUDA_CALL_THROW(cudaSetDevice(static_cast<int>(device_id)));
30+
cudaStream_t stream = nullptr;
31+
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
32+
stream_ = stream;
33+
own_stream_ = true;
34+
}
35+
}
36+
37+
/*static*/
38+
OrtStatus* ORT_API_CALL TrtSyncStreamImpl::CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr,
39+
_Outptr_ OrtSyncNotificationImpl** notification) noexcept {
40+
auto& impl = *static_cast<TrtSyncStreamImpl*>(this_ptr);
41+
42+
std::unique_ptr<TrtSyncNotificationImpl> trt_sync_notification;
43+
RETURN_IF_ERROR(TrtSyncNotificationImpl::Create(impl.stream_, impl, trt_sync_notification));
44+
45+
*notification = trt_sync_notification.release();
46+
return nullptr;
47+
}
48+
49+
/*static*/
50+
void* ORT_API_CALL TrtSyncStreamImpl::GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept {
51+
auto& impl = *static_cast<TrtSyncStreamImpl*>(this_ptr);
52+
return static_cast<void*>(impl.stream_);
53+
}
54+
55+
/*static*/
56+
OrtStatus* ORT_API_CALL TrtSyncStreamImpl::FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept {
57+
auto& impl = *static_cast<TrtSyncStreamImpl*>(this_ptr);
58+
59+
// only flush when we own the stream, not external
60+
if (impl.own_stream_) CUDA_CALL_THROW(cudaStreamSynchronize(static_cast<cudaStream_t>(impl.stream_)));
61+
return nullptr;
62+
}
63+
64+
/*static*/
65+
OrtStatus* ORT_API_CALL TrtSyncStreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept {
66+
return nullptr;
67+
}
68+
69+
// callback for EP library to release any internal state
70+
/*static*/
71+
void ORT_API_CALL TrtSyncStreamImpl::ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept {
72+
delete static_cast<TrtSyncStreamImpl*>(this_ptr);
73+
}
74+
75+
//
76+
// Notification support
77+
//
78+
79+
/*static*/
80+
OrtStatus* TrtSyncNotificationImpl::Create(cudaStream_t stream, const ApiPtrs& apis,
81+
std::unique_ptr<TrtSyncNotificationImpl>& notification){
82+
auto trt_sync_notification = std::make_unique<TrtSyncNotificationImpl>(stream, apis);
83+
CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&trt_sync_notification->event_, cudaEventDisableTiming));
84+
85+
notification = std::move(trt_sync_notification);
86+
87+
return nullptr;
88+
}
89+
90+
/*static*/
91+
OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept {
92+
auto& impl = *static_cast<TrtSyncNotificationImpl*>(this_ptr);
93+
CUDA_RETURN_IF_ERROR(cudaEventRecord(impl.event_, impl.stream_));
94+
95+
return nullptr;
96+
}
97+
98+
/*static*/
99+
OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr,
100+
_In_ OrtSyncStream* stream) noexcept {
101+
auto& impl = *static_cast<TrtSyncNotificationImpl*>(this_ptr);
102+
void* handle = impl.ort_api.SyncStream_GetHandle(stream);
103+
CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(static_cast<cudaStream_t>(handle), impl.event_));
104+
105+
return nullptr;
106+
}
107+
108+
/*static*/
109+
OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept {
110+
auto& impl = *static_cast<TrtSyncNotificationImpl*>(this_ptr);
111+
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(impl.event_));
112+
113+
return nullptr;
114+
}
115+
116+
/*static*/
117+
void ORT_API_CALL TrtSyncNotificationImpl::ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept {
118+
delete static_cast<TrtSyncNotificationImpl*>(this_ptr);
119+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "onnxruntime_c_api.h"
7+
#include "tensorrt_provider_factory.h"
8+
#include "ep_utils.h"
9+
10+
#include <cuda_runtime_api.h>
11+
12+
//
13+
// Class implementing Stream support for synchronization.
14+
//
15+
struct TrtSyncStreamImpl : public OrtSyncStreamImpl, public ApiPtrs {
16+
TrtSyncStreamImpl(TensorrtExecutionProviderFactory& factory,
17+
const OrtEp* ep,
18+
uint32_t device_id,
19+
const OrtKeyValuePairs* /*stream_options*/);
20+
21+
private:
22+
static OrtStatus* ORT_API_CALL CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr,
23+
_Outptr_ OrtSyncNotificationImpl** sync_notification) noexcept;
24+
static void* ORT_API_CALL GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept;
25+
static OrtStatus* ORT_API_CALL FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept;
26+
static OrtStatus* ORT_API_CALL OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept;
27+
static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept;
28+
29+
// EP instance if the stream is being created internally for inferencing.
30+
// nullptr when the stream is created outside of an inference session for data copies.
31+
const OrtEp* ep_;
32+
TensorrtExecutionProviderFactory* factory_{nullptr};
33+
34+
cudaStream_t stream_{nullptr};
35+
bool own_stream_{true};
36+
};
37+
38+
//
39+
// Class implementing synchronization notification support.
40+
//
41+
struct TrtSyncNotificationImpl : public OrtSyncNotificationImpl, public ApiPtrs {
42+
static OrtStatus* Create(cudaStream_t stream, const ApiPtrs& apis,
43+
std::unique_ptr<TrtSyncNotificationImpl>& notification);
44+
45+
TrtSyncNotificationImpl(cudaStream_t stream, const ApiPtrs& apis) : stream_(stream), ApiPtrs(apis) {
46+
ort_version_supported = ORT_API_VERSION;
47+
Activate = ActivateImpl;
48+
Release = ReleaseImpl;
49+
WaitOnDevice = WaitOnDeviceImpl;
50+
WaitOnHost = WaitOnHostImpl;
51+
}
52+
53+
private:
54+
static OrtStatus* ORT_API_CALL ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept;
55+
static OrtStatus* ORT_API_CALL WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr,
56+
_In_ OrtSyncStream* stream) noexcept;
57+
static OrtStatus* ORT_API_CALL WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept;
58+
static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept;
59+
60+
cudaStream_t& stream_;
61+
cudaEvent_t event_;
62+
};

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl
249249
}
250250

251251
bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
252-
return false;
252+
return true;
253253
}
254254

255255
// To make symbols visible on macOS/iOS

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
5555

5656
static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept;
5757

58-
void SetGPUDataTransfer(std::unique_ptr<TRTEpDataTransfer> gpu_data_transfer);
59-
6058
const std::string ep_name_; // EP name
6159
const std::string vendor_{"Nvidia"}; // EP vendor name
6260
const std::string ep_version_{"0.1.0"}; // EP version

plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,4 @@ std::conditional_t<THRW, void, OrtStatus*> CudaCall(
6060
//ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line);
6161

6262
#define CUDA_CALL(expr) (CudaCall<cudaError, false>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
63+
#define CUDA_CALL_THROW(expr) (CudaCall<cudaError, true>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))

0 commit comments

Comments
 (0)