Skip to content

Commit 6d04a2e

Browse files
authored
feat: Enable EpContext OVIR Encapsulation (#704)
* feat: Enable EpContext OVIR Encapsulation * fix: refactor EpCtx OVIR parsing logic to use ep.context_file_path * fix: Fix logic for parsing model_file_path * fix: enable EPCtx OVIR encapsulation compiled blob caching * fix: fix merge conflicts * fix: fix bugs
1 parent cbef617 commit 6d04a2e

File tree

9 files changed

+169
-12
lines changed

9 files changed

+169
-12
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ BackendManager::BackendManager(SessionContext& session_context,
4343
session_context_(session_context),
4444
shared_context_{shared_context} {
4545
subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph);
46+
// If the graph contains a OVIR wrapped node, we check if it has matching xml file name attribute
47+
subgraph_context_.is_ep_ctx_ovir_encapsulated = ep_ctx_handle_.CheckEPCacheContextAttribute(subgraph,
48+
session_context_.onnx_model_path_name.filename().replace_extension("xml").string());
4649

4750
subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) {
4851
// return empty if graph has no inputs or if types are not one of FP32/FP16
@@ -192,9 +195,10 @@ BackendManager::BackendManager(SessionContext& session_context,
192195
}
193196
}
194197
}
195-
if (session_context_.so_context_enable && !subgraph_context_.is_ep_ctx_graph) {
198+
if (session_context_.so_context_enable &&
199+
(subgraph_context_.is_ep_ctx_ovir_encapsulated || !subgraph_context_.is_ep_ctx_graph)) {
196200
auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph);
197-
if ((!status.IsOK())) {
201+
if (!status.IsOK()) {
198202
ORT_THROW(status);
199203
}
200204
}

onnxruntime/core/providers/openvino/backend_utils.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,33 @@ void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map)
400400
metadata_map.clear();
401401
}
402402

403+
bool IsModelStreamXML(std::istream& model_stream) {
404+
std::streampos originalPos = model_stream.tellg();
405+
406+
// first, get the total size of model_stream in bytes
407+
model_stream.seekg(0, std::ios::end);
408+
auto end_pos = model_stream.tellg();
409+
// Restore the stream position
410+
model_stream.seekg(originalPos);
411+
auto total_size = end_pos - originalPos;
412+
413+
// Choose 32 bytes to hold content of:
414+
// '<?xml version-"1.0"?> <net '
415+
const std::streamsize header_check_len = 32;
416+
ORT_ENFORCE(total_size > header_check_len);
417+
418+
// read 32 bytes into header
419+
std::string header(header_check_len, '\0');
420+
model_stream.read(&header[0], header_check_len);
421+
// Clear any read errors
422+
model_stream.clear();
423+
// Restore the stream position
424+
model_stream.seekg(originalPos);
425+
426+
// return true if the header starts with '<?xml' and also includes '<net '
427+
return ((header.rfind("<?xml", 0) == 0) && (header.find("<net ") != std::string::npos));
428+
}
429+
403430
} // namespace backend_utils
404431
} // namespace openvino_ep
405432
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backend_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ void printPerformanceCounts(const std::vector<OVProfilingInfo>& performanceMap,
107107

108108
void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName);
109109

110+
bool IsModelStreamXML(std::istream& model_stream);
111+
110112
} // namespace backend_utils
111113
} // namespace openvino_ep
112114
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,38 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
7272
!session_context_.so_disable_cpu_ep_fallback &&
7373
!subgraph_context_.is_ep_ctx_graph);
7474
if (subgraph_context_.is_ep_ctx_graph) {
75-
// If the blob is held in an EPContext node, then skip FE+Compile
76-
// and directly move on to creating a backend with the executable blob
77-
exe_network_ = OVCore::Get()->ImportModel(*model_stream,
78-
hw_target,
79-
device_config,
80-
subgraph_context_.subgraph_name);
75+
if (subgraph_context_.is_ep_ctx_ovir_encapsulated) {
76+
// model_file_path will use so_context_file_path if the onnx_model_path_name is not available,
77+
// especially in case of CreateSessionFormArray() where user must explicitly
78+
// specify absolute path for so_context_file_path.
79+
auto model_file_path = [this]() {
80+
if (!session_context_.onnx_model_path_name.empty() &&
81+
std::filesystem::exists(session_context_.onnx_model_path_name)) return session_context_.onnx_model_path_name;
82+
83+
ORT_ENFORCE(!session_context_.so_context_file_path.empty() &&
84+
std::filesystem::path(session_context_.so_context_file_path).is_absolute() &&
85+
std::filesystem::exists(session_context_.so_context_file_path), log_tag +
86+
"Context file path must be non-empty & absolute, when using CreateSessionFormArray() API explicitly."
87+
" Please set a valid absolute path for ep.context_file_path in session options.");
88+
// Return absolute context file path as input to ImportEPCtxOVIREncapsulation() function.
89+
return session_context_.so_context_file_path;
90+
91+
};
92+
// If the EPContext node with OVIR Encapsulation, then create
93+
// an executable network from EP_CACHE_CONTEXT using read_model() & compile_model()
94+
exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream,
95+
hw_target,
96+
device_config,
97+
enable_causallm,
98+
model_file_path());
99+
} else {
100+
// If the blob is held in an EPContext node, then skip FE+Compile
101+
// and directly move on to creating a backend with the executable blob
102+
exe_network_ = OVCore::Get()->ImportModel(*model_stream,
103+
hw_target,
104+
device_config,
105+
subgraph_context_.subgraph_name);
106+
}
81107
model_stream.reset(); // Delete stream after it is no longer needed
82108
} else if (!session_context_.has_external_weights &&
83109
!subgraph_context_.has_dynamic_input_shape &&

onnxruntime/core/providers/openvino/contexts.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ struct SubGraphContext {
137137
string_index_map_t output_names;
138138
std::string model_precision;
139139
bool is_ep_ctx_graph = false;
140+
bool is_ep_ctx_ovir_encapsulated = false;
140141
};
141142

142143
} // namespace openvino_ep

onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <algorithm>
88

99
#include "core/providers/openvino/onnx_ctx_model_helper.h"
10+
#include "core/providers/openvino/backend_utils.h"
1011

1112
namespace onnxruntime {
1213
namespace openvino_ep {
@@ -123,6 +124,16 @@ std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const std::filesy
123124
ORT_ENFORCE(std::filesystem::exists(blob_filepath), "Blob file not found: ", blob_filepath.string());
124125
result.reset((std::istream*)new std::ifstream(blob_filepath, std::ios_base::binary | std::ios_base::in));
125126
}
127+
128+
bool isXML = backend_utils::IsModelStreamXML(*result);
129+
if (!isXML) {
130+
// If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was
131+
// exported with must match the version that is currently running.
132+
ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_),
133+
"EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() +
134+
", but OpenVINO SDK version currently in use is " + openvino_sdk_version_);
135+
}
136+
126137
LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node";
127138
return result;
128139
}
@@ -142,7 +153,6 @@ bool EPCtxHandler::CheckForOVEPCtxNode(const Node& node) const {
142153
if (node.OpType() == EPCONTEXT_OP) {
143154
auto& attrs = node.GetAttributes();
144155
bool result = (attrs.count(SOURCE) == 1) && (attrs.at(SOURCE).s() == kOpenVINOExecutionProvider);
145-
result &= (attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_);
146156
result &= attrs.count(EMBED_MODE) == 1;
147157
result &= attrs.count(EP_CACHE_CONTEXT) == 1;
148158
return result;
@@ -155,5 +165,32 @@ InlinedVector<const Node*> EPCtxHandler::GetEPCtxNodes() const {
155165
return InlinedVector<const Node*>(epctx_nodes.begin(), epctx_nodes.end());
156166
}
157167

168+
// Check if graph's only node is EPContext & EP_CACHE_CONTEXT attribute has target extension.
169+
// @param graph_viewer: The graph to inspect.
170+
// @param target_attr_extn: The string to search for in the EP_CACHE_CONTEXT attribute.
171+
// @return true if the node exists, is of the correct type, and the attribute contains the extension; false otherwise.
172+
bool EPCtxHandler::CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const {
173+
// Only check if the graph has exactly one node
174+
if (graph_viewer.NumberOfNodes() != 1) {
175+
return false;
176+
}
177+
// Get the first node in topological order
178+
auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin();
179+
const Node* node = graph_viewer.GetNode(first_index);
180+
if (!node) {
181+
return false;
182+
}
183+
// Check OpType and required attributes
184+
if (node->OpType() != EPCONTEXT_OP) {
185+
return false;
186+
}
187+
const auto& attrs = node->GetAttributes();
188+
auto it = attrs.find(EP_CACHE_CONTEXT);
189+
if (it != attrs.end()) {
190+
return it->second().s().find(target_attr_extn) != std::string::npos;
191+
}
192+
return false;
193+
}
194+
158195
} // namespace openvino_ep
159196
} // namespace onnxruntime

onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class EPCtxHandler {
3333
std::string&& model_blob_str) const;
3434
std::unique_ptr<std::istream> GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const;
3535
InlinedVector<const Node*> GetEPCtxNodes() const;
36+
bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const;
3637

3738
private:
3839
const std::string openvino_sdk_version_;

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ void printDebugInfo(const ov::CompiledModel& obj) {
4747
continue;
4848
OPENVINO_SUPPRESS_DEPRECATED_END
4949
std::cout << " " << item2.first << ": " << item2.second.as<std::string>() << std::endl;
50-
}
5150
}
5251
} else {
5352
std::cout << " " << cfg << ": " << prop.as<std::string>() << std::endl;
@@ -101,10 +100,10 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
101100
LogBasicModelInfo(model);
102101
}
103102

104-
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
105103
bool model_status = IsStateful(model);
106104
LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
107105
if (!model_status) {
106+
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
108107
PatchStatefulDecoder(model);
109108
}
110109

@@ -198,15 +197,69 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
198197
return OvExceptionBoundary([&]() {
199198
ov::CompiledModel obj;
200199
obj = core.import_model(model_stream, hw_target, device_config);
200+
OVExeNetwork exe(obj, hw_target);
201201
#ifndef NDEBUG
202202
printDebugInfo(exe.Get());
203203
#endif
204-
OVExeNetwork exe(obj, hw_target);
205204
return exe;
206205
},
207206
"Exception while Loading Network for graph {}", name);
208207
}
209208

209+
OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream,
210+
std::string& hw_target,
211+
const ov::AnyMap& device_config,
212+
bool enable_causallm,
213+
std::filesystem::path model_file_path) {
214+
return OvExceptionBoundary([&]() {
215+
OVExeNetwork exe;
216+
217+
bool isXML = backend_utils::IsModelStreamXML(model_stream);
218+
219+
// Helper function to check if file exists and is readable
220+
const auto check_file_access = [&model_file_path](const std::filesystem::path& path) {
221+
try {
222+
if (!std::filesystem::exists(path) || std::filesystem::is_empty(path)) {
223+
ORT_THROW(log_tag + "Required file missing or empty: " + path.string());
224+
}
225+
std::ifstream file(path);
226+
if (!file) {
227+
ORT_THROW(log_tag + "Required file not readable: " + path.string());
228+
}
229+
} catch (const std::exception& e) {
230+
ORT_THROW(log_tag + "Exception while checking file access for: " + path.string() + " - " + e.what());
231+
}
232+
};
233+
234+
if (isXML) {
235+
// If the model is XML, we need to load it with the XML content in read_model()
236+
// where weights from bin file is directly consumed
237+
auto xml_file_path = model_file_path.parent_path() / (model_file_path.stem().string() + ".xml");
238+
239+
check_file_access(xml_file_path);
240+
241+
LOGS_DEFAULT(INFO) << log_tag << "Reading OVIR from XML file path: " << xml_file_path.string();
242+
243+
// Load the model explicitly with XML contents
244+
std::shared_ptr<ov::Model> model = core.read_model(xml_file_path.string());
245+
246+
if (enable_causallm) {
247+
exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
248+
} else {
249+
auto obj = core.compile_model(model, hw_target, device_config);
250+
exe = OVExeNetwork(obj, hw_target);
251+
}
252+
}
253+
254+
#ifndef NDEBUG
255+
printDebugInfo(exe.Get());
256+
#endif
257+
return exe;
258+
},
259+
"Exception while Loading Network from OVIR model file: {}", model_file_path.string());
260+
}
261+
262+
210263
void OVCore::SetCache(const std::string& cache_dir_path) {
211264
core.set_property(ov::cache_dir(cache_dir_path));
212265
}

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ struct OVCore : WeakSingleton<OVCore> {
8686
std::string hw_target,
8787
const ov::AnyMap& device_config,
8888
std::string name);
89+
OVExeNetwork ImportEPCtxOVIREncapsulation(std::istream& model_stream,
90+
std::string& hw_target,
91+
const ov::AnyMap& device_config,
92+
bool enable_causallm,
93+
std::filesystem::path model_file_path);
94+
8995
std::vector<std::string> GetAvailableDevices() const;
9096
std::vector<std::string> GetAvailableDevices(const std::string& device_type) const;
9197
void SetCache(const std::string& cache_dir_path);

0 commit comments

Comments
 (0)