77
88#include " ep_utils.h"
99#include " onnx_ctx_model_helper.h"
10+ #include " onnx/onnx_pb.h"
1011
1112extern TensorrtLogger& GetTensorrtLogger (bool verbose_log);
1213
@@ -28,7 +29,7 @@ bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& o
2829
2930 const char * op_type = nullptr ;
3031 RETURN_IF_ERROR (ort_api.Node_GetOperatorType (node, &op_type));
31- if (node != nullptr && op_type == " EPContext" ) {
32+ if (node != nullptr && std::string ( op_type) == " EPContext" ) {
3233 return true ;
3334 }
3435 }
@@ -85,21 +86,21 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca
8586 std::array<OrtOpAttr*, 4 > attributes = {};
8687 DeferOrtRelease<OrtOpAttr> defer_release_attrs (attributes.data (), attributes.size (), ort_api.ReleaseOpAttr );
8788
88- RETURN_IF_ERROR (ort_api.CreateOpAttr (" embed_mode" , &embed_mode, 1 , ORT_OP_ATTR_INT, &attributes[0 ]));
89+ RETURN_IF_ERROR (ort_api.CreateOpAttr (" embed_mode" , &embed_mode, sizeof ( int64_t ) , ORT_OP_ATTR_INT, &attributes[0 ]));
8990
9091 std::string engine_data_str = " " ;
9192 if (embed_mode) {
9293 if (size > 0 ) {
9394 engine_data_str.assign (engine_data, size);
9495 }
9596 RETURN_IF_ERROR (
96- ort_api.CreateOpAttr (" ep_cache_context" , engine_data_str.c_str (), 1 , ORT_OP_ATTR_STRING, &attributes[1 ]));
97+ ort_api.CreateOpAttr (" ep_cache_context" , engine_data_str.c_str (), engine_data_str. size () , ORT_OP_ATTR_STRING, &attributes[1 ]));
9798 } else {
98- RETURN_IF_ERROR (ort_api.CreateOpAttr (" ep_cache_context" , engine_cache_path.c_str (), 1 , ORT_OP_ATTR_STRING, &attributes[1 ]));
99+ RETURN_IF_ERROR (ort_api.CreateOpAttr (" ep_cache_context" , engine_cache_path.c_str (), engine_cache_path. size () , ORT_OP_ATTR_STRING, &attributes[1 ]));
99100 }
100101
101102
102- ort_api.CreateOpAttr (" hardware_architecture" , compute_capability.c_str (), 1 , ORT_OP_ATTR_STRING, &attributes[2 ]);
103+ ort_api.CreateOpAttr (" hardware_architecture" , compute_capability.c_str (), compute_capability. size () , ORT_OP_ATTR_STRING, &attributes[2 ]);
103104 ort_api.CreateOpAttr (" onnx_model_filename" , std::filesystem::path (onnx_model_path).filename ().string ().c_str (), 1 ,
104105 ORT_OP_ATTR_STRING, &attributes[3 ]);
105106
@@ -111,6 +112,153 @@ OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_ca
111112 return nullptr ;
112113}
113114
115+ OrtStatus* EPContextNodeReader::GetEpContextFromGraph (const OrtGraph& graph) {
116+ /*
117+ if (!ValidateEPCtxNode(graph)) {
118+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
119+ }
120+ */
121+
122+ size_t num_nodes = 0 ;
123+ RETURN_IF_ERROR (ort_api.Graph_GetNumNodes (&graph, &num_nodes));
124+
125+ std::vector<const OrtNode*> nodes (num_nodes);
126+ RETURN_IF_ERROR (ort_api.Graph_GetNodes (&graph, nodes.data (), nodes.size ()));
127+
128+ auto node = nodes[0 ];
129+
130+ size_t num_node_attributes = 0 ;
131+ RETURN_IF_ERROR (ort_api.Node_GetNumAttributes (node, &num_node_attributes));
132+
133+ /*
134+ std::vector<const OrtOpAttr*> node_attributes(num_node_attributes);
135+ RETURN_IF_ERROR(ort_api.Node_GetAttributes(node, node_attributes.data(), node_attributes.size()));
136+ */
137+
138+ const OrtOpAttr* node_attr = nullptr ;
139+ RETURN_IF_ERROR (ort_api.Node_GetAttributeByName (node, " embed_mode" , &node_attr));
140+ const int64_t embed_mode = reinterpret_cast <const ONNX_NAMESPACE::AttributeProto*>(node_attr)->i ();
141+
142+ // Only make path checks if model not provided as byte buffer
143+ // bool make_secure_path_checks = !GetModelPath(graph_viewer).empty();
144+ bool make_secure_path_checks = false ;
145+
146+ if (embed_mode) {
147+ // Get engine from byte stream.
148+ node_attr = nullptr ;
149+ RETURN_IF_ERROR (ort_api.Node_GetAttributeByName (node, " ep_cache_context" , &node_attr));
150+ const std::string& context_binary = reinterpret_cast <const ONNX_NAMESPACE::AttributeProto*>(node_attr)->s ();
151+
152+ *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine (const_cast <char *>(context_binary.c_str ()),
153+ static_cast <size_t >(context_binary.length ())));
154+ // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it";
155+ if (!(*trt_engine_)) {
156+ return ort_api.CreateStatus (ORT_EP_FAIL, " TensorRT EP could not deserialize engine from binary data" );
157+ }
158+
159+ /*
160+ if (weight_stripped_engine_refit_) {
161+ const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s();
162+ std::string placeholder;
163+ auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
164+ onnx_model_folder_path_,
165+ placeholder,
166+ make_secure_path_checks,
167+ onnx_model_bytestream_,
168+ onnx_model_bytestream_size_,
169+ onnx_external_data_bytestream_,
170+ onnx_external_data_bytestream_size_,
171+ (*trt_engine_).get(),
172+ false, // serialize refitted engine to disk
173+ detailed_build_log_);
174+ if (status != Status::OK()) {
175+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
176+ }
177+ }
178+ */
179+ } else {
180+ // Get engine from cache file.
181+ node_attr = nullptr ;
182+ RETURN_IF_ERROR (ort_api.Node_GetAttributeByName (node, " ep_cache_context" , &node_attr));
183+ std::string cache_path = reinterpret_cast <const ONNX_NAMESPACE::AttributeProto*>(node_attr)->s ();
184+
185+ /*
186+ // For security purpose, in the case of running context model, TRT EP won't allow
187+ // engine cache path to be the relative path like "../file_path" or the absolute path.
188+ // It only allows the engine cache to be in the same directory or sub directory of the context model.
189+ if (IsAbsolutePath(cache_path)) {
190+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path);
191+ }
192+ if (IsRelativePathToParentPath(cache_path)) {
193+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory.");
194+ }
195+
196+ // The engine cache and context model (current model) should be in the same directory
197+ std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
198+ auto engine_cache_path = ctx_model_dir.append(cache_path);
199+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string();
200+
201+ // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled
202+ if (!weight_stripped_engine_refit_) {
203+ weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path);
204+ }
205+
206+ // If the serialized refitted engine is present, use it directly without refitting the engine again
207+ if (weight_stripped_engine_refit_) {
208+ const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string());
209+ if (std::filesystem::exists(refitted_engine_cache_path)) {
210+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists.";
211+ engine_cache_path = refitted_engine_cache_path.string();
212+ weight_stripped_engine_refit_ = false;
213+ }
214+ }
215+ */
216+
217+ std::filesystem::path engine_cache_path (cache_path);
218+ if (!std::filesystem::exists (engine_cache_path)) {
219+ std::string error_msg =
220+ " TensorRT EP can't find engine cache: " + engine_cache_path.string () +
221+ " . Please make sure engine cache is in the same directory or sub-directory of context model." ;
222+ return ort_api.CreateStatus (ORT_EP_FAIL, error_msg.c_str ());
223+ }
224+
225+ std::ifstream engine_file (engine_cache_path.string (), std::ios::binary | std::ios::in);
226+ engine_file.seekg (0 , std::ios::end);
227+ size_t engine_size = engine_file.tellg ();
228+ engine_file.seekg (0 , std::ios::beg);
229+ std::unique_ptr<char []> engine_buf{new char [engine_size]};
230+ engine_file.read ((char *)engine_buf.get (), engine_size);
231+ *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine (engine_buf.get (), engine_size));
232+ if (!(*trt_engine_)) {
233+ std::string error_msg = " TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string ();
234+ return ort_api.CreateStatus (ORT_EP_FAIL, error_msg.c_str ());
235+ }
236+ // LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string();
237+
238+ /*
239+ if (weight_stripped_engine_refit_) {
240+ const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s();
241+ std::string weight_stripped_engine_cache = engine_cache_path.string();
242+ auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
243+ onnx_model_folder_path_,
244+ weight_stripped_engine_cache,
245+ make_secure_path_checks,
246+ onnx_model_bytestream_,
247+ onnx_model_bytestream_size_,
248+ onnx_external_data_bytestream_,
249+ onnx_external_data_bytestream_size_,
250+ (*trt_engine_).get(),
251+ true, // serialize refitted engine to disk
252+ detailed_build_log_);
253+ if (status != Status::OK()) {
254+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
255+ }
256+ }
257+ */
258+ }
259+ return nullptr ;
260+ }
261+
114262/*
115263 * Get the weight-refitted engine cache path from a weight-stripped engine cache path
116264 *
0 commit comments