Skip to content

Commit da0f9c6

Browse files
committed
Add support for dumping and running EP Context models
1 parent 7467c65 commit da0f9c6

File tree

4 files changed

+1055
-430
lines changed

4 files changed

+1055
-430
lines changed

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc

Lines changed: 153 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include "ep_utils.h"
99
#include "onnx_ctx_model_helper.h"
10+
#include "onnx/onnx_pb.h"
1011

1112
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
1213

@@ -28,7 +29,7 @@ bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& o
2829

2930
const char* op_type = nullptr;
3031
RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node, &op_type));
31-
if (node != nullptr && op_type == "EPContext") {
32+
if (node != nullptr && std::string(op_type) == "EPContext") {
3233
return true;
3334
}
3435
}
@@ -85,21 +86,21 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca
8586
std::array<OrtOpAttr*, 4> attributes = {};
8687
DeferOrtRelease<OrtOpAttr> defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr);
8788

88-
RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[0]));
89+
RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, sizeof(int64_t), ORT_OP_ATTR_INT, &attributes[0]));
8990

9091
std::string engine_data_str = "";
9192
if (embed_mode) {
9293
if (size > 0) {
9394
engine_data_str.assign(engine_data, size);
9495
}
9596
RETURN_IF_ERROR(
96-
ort_api.CreateOpAttr("ep_cache_context", engine_data_str.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1]));
97+
ort_api.CreateOpAttr("ep_cache_context", engine_data_str.c_str(), engine_data_str.size(), ORT_OP_ATTR_STRING, &attributes[1]));
9798
} else {
98-
RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1]));
99+
RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), engine_cache_path.size(), ORT_OP_ATTR_STRING, &attributes[1]));
99100
}
100101

101102

102-
ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[2]);
103+
ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), compute_capability.size(), ORT_OP_ATTR_STRING, &attributes[2]);
103104
ort_api.CreateOpAttr("onnx_model_filename", std::filesystem::path(onnx_model_path).filename().string().c_str(), 1,
104105
ORT_OP_ATTR_STRING, &attributes[3]);
105106

@@ -111,6 +112,153 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca
111112
return nullptr;
112113
}
113114

115+
/// Deserializes a TensorRT engine from the EPContext node contained in `graph`
/// and stores it into *trt_engine_.
///
/// The node's "embed_mode" attribute selects the source:
///   - embed_mode != 0: the serialized engine bytes are embedded in the node's
///     "ep_cache_context" attribute.
///   - embed_mode == 0: "ep_cache_context" holds a path to an engine-cache file
///     on disk, which is read and deserialized.
///
/// Returns nullptr on success, or an OrtStatus describing the failure.
OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) {
  // TODO: re-enable EPContext node validation once ValidateEPCtxNode is ported.

  size_t num_nodes = 0;
  RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes));
  if (num_nodes == 0) {
    // Guard the nodes[0] access below: an empty graph cannot carry an EPContext node.
    return ort_api.CreateStatus(ORT_EP_FAIL, "TensorRT EP: EP context graph contains no nodes");
  }

  std::vector<const OrtNode*> nodes(num_nodes);
  RETURN_IF_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size()));

  // Assumes the EPContext node is the first (and only) node in the graph —
  // TODO confirm this invariant holds for all dumped EP context models.
  const OrtNode* node = nodes[0];

  const OrtOpAttr* node_attr = nullptr;
  RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "embed_mode", &node_attr));
  // OrtOpAttr is an AttributeProto underneath in this plugin API surface.
  const int64_t embed_mode = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(node_attr)->i();

  if (embed_mode) {
    // Engine is embedded in the node attribute as a byte string.
    node_attr = nullptr;
    RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr));
    const std::string& context_binary = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(node_attr)->s();

    *trt_engine_ = std::unique_ptr<nvinfer1::ICudaEngine>(
        trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
                                            static_cast<size_t>(context_binary.length())));
    if (!(*trt_engine_)) {
      return ort_api.CreateStatus(ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data");
    }
    // TODO: port weight-stripped engine refit (TensorrtExecutionProvider::RefitEngine)
    // for the embed_mode path, as done in the in-tree TRT EP.
  } else {
    // Engine lives in a cache file referenced by the attribute.
    node_attr = nullptr;
    RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr));
    const std::string cache_path = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(node_attr)->s();

    // TODO: port the security path checks from the in-tree TRT EP (reject absolute
    // paths and "..": the cache must live in the same directory, or a sub-directory,
    // of the context model), plus weight-stripped-engine refit detection.

    std::filesystem::path engine_cache_path(cache_path);
    if (!std::filesystem::exists(engine_cache_path)) {
      std::string error_msg =
          "TensorRT EP can't find engine cache: " + engine_cache_path.string() +
          ". Please make sure engine cache is in the same directory or sub-directory of context model.";
      return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str());
    }

    std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in);
    if (!engine_file) {
      std::string error_msg = "TensorRT EP could not open engine cache: " + engine_cache_path.string();
      return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str());
    }
    engine_file.seekg(0, std::ios::end);
    const std::streampos end_pos = engine_file.tellg();
    if (end_pos < 0) {
      // tellg() reports -1 on failure; without this check the size below would
      // wrap to a huge value and the allocation would throw or OOM.
      std::string error_msg = "TensorRT EP could not determine size of engine cache: " + engine_cache_path.string();
      return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str());
    }
    const size_t engine_size = static_cast<size_t>(end_pos);
    engine_file.seekg(0, std::ios::beg);
    std::unique_ptr<char[]> engine_buf{new char[engine_size]};
    engine_file.read(engine_buf.get(), engine_size);
    if (!engine_file) {
      std::string error_msg = "TensorRT EP failed to read engine cache: " + engine_cache_path.string();
      return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str());
    }

    *trt_engine_ = std::unique_ptr<nvinfer1::ICudaEngine>(
        trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
    if (!(*trt_engine_)) {
      std::string error_msg = "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string();
      return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str());
    }

    // TODO: port weight-stripped engine refit (TensorrtExecutionProvider::RefitEngine)
    // for the cache-file path, serializing the refitted engine back to disk.
  }
  return nullptr;
}
261+
114262
/*
115263
* Get the weight-refitted engine cache path from a weight-stripped engine cache path
116264
*

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,49 @@ class EPContextNodeHelper : public ApiPtrs {
3434
const OrtGraph* graph_ = nullptr;
3535
const OrtNode* fused_node_ = nullptr;
3636
};
37+
38+
class EPContextNodeReader : public ApiPtrs {
39+
public:
40+
EPContextNodeReader(TensorrtExecutionProvider& ep,
41+
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
42+
nvinfer1::IRuntime* trt_runtime,
43+
std::string ep_context_model_path,
44+
std::string compute_capability,
45+
bool weight_stripped_engine_refit,
46+
std::string onnx_model_folder_path,
47+
const void* onnx_model_bytestream,
48+
size_t onnx_model_bytestream_size,
49+
const void* onnx_external_data_bytestream,
50+
size_t onnx_external_data_bytestream_size,
51+
bool detailed_build_log)
52+
: ApiPtrs{static_cast<const ApiPtrs&>(ep)},
53+
trt_engine_(trt_engine),
54+
trt_runtime_(trt_runtime),
55+
ep_context_model_path_(ep_context_model_path),
56+
compute_capability_(compute_capability),
57+
weight_stripped_engine_refit_(weight_stripped_engine_refit),
58+
onnx_model_folder_path_(onnx_model_folder_path),
59+
onnx_model_bytestream_(onnx_model_bytestream),
60+
onnx_model_bytestream_size_(onnx_model_bytestream_size),
61+
onnx_external_data_bytestream_(onnx_external_data_bytestream),
62+
onnx_external_data_bytestream_size_(onnx_external_data_bytestream_size),
63+
detailed_build_log_(detailed_build_log) {
64+
}
65+
66+
//bool ValidateEPCtxNode(const OrtGraph& graph);
67+
68+
OrtStatus* GetEpContextFromGraph(const OrtGraph& graph);
69+
70+
private:
71+
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
72+
nvinfer1::IRuntime* trt_runtime_;
73+
std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory
74+
std::string compute_capability_;
75+
bool weight_stripped_engine_refit_;
76+
std::string onnx_model_folder_path_;
77+
const void* onnx_model_bytestream_;
78+
size_t onnx_model_bytestream_size_;
79+
const void* onnx_external_data_bytestream_;
80+
size_t onnx_external_data_bytestream_size_;
81+
bool detailed_build_log_;
82+
}; // TRTCacheModelHandler

0 commit comments

Comments
 (0)