Skip to content

Commit cc98dd8

Browse files
RyanMetcalfeInt8ankitm3k
authored andcommitted
ovep: Support multiple devices (i.e. AUTO) passed to CreateIExecutionProvider
1 parent 271cea0 commit cc98dd8

File tree

3 files changed

+61
-16
lines changed

3 files changed

+61
-16
lines changed

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -433,45 +433,85 @@ struct OpenVINO_Provider : Provider {
433433
return std::make_shared<OpenVINOProviderFactory>(pi, SharedContext::Get());
434434
}
435435

436-
Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/,
436+
Status CreateIExecutionProvider(const OrtHardwareDevice* const* devices,
437437
const OrtKeyValuePairs* const* ep_metadata,
438438
size_t num_devices,
439439
ProviderOptions& provider_options,
440440
const OrtSessionOptions& session_options,
441441
const OrtLogger& logger,
442442
std::unique_ptr<IExecutionProvider>& ep) override {
443-
if (num_devices != 1) {
444-
return Status(common::ONNXRUNTIME, ORT_EP_FAIL, "OpenVINO EP only supports one device.");
443+
// Check if no devices are provided
444+
if (num_devices == 0) {
445+
return Status(common::ONNXRUNTIME, ORT_EP_FAIL, "No devices provided to CreateEp");
445446
}
446447

447448
// Block setting certain provider options via AppendExecutionProvider_V2
449+
// TODO: Expand this out and give better guidance for keys that should now flow through load_config.
448450
const std::unordered_set<std::string> blocked_provider_keys = {
449451
"device_type", "device_id", "device_luid", "cache_dir", "precision",
450452
"context", "num_of_threads", "model_priority", "num_streams",
451-
"enable_opencl_throttling", "enable_qdq_optimizer", "disable_dynamic_shapes"
452-
};
453+
"enable_opencl_throttling", "enable_qdq_optimizer", "disable_dynamic_shapes"};
453454

454455
for (const auto& key : blocked_provider_keys) {
455456
if (provider_options.find(key) != provider_options.end()) {
456-
return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT,
457-
"OpenVINO EP: Option '" + key + "' cannot be set explicitly when using AppendExecutionProvider_V2.");
457+
return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT,
458+
"OpenVINO EP: Option '" + key + "' cannot be set explicitly when using AppendExecutionProvider_V2.");
458459
}
459460
}
460461

461-
// Extract device type from EP metadata
462-
const auto& device_meta_data = ep_metadata[0];
463-
auto it = device_meta_data->Entries().find("ov_device");
464-
if (it == device_meta_data->Entries().end()) {
465-
return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, "OpenVINO EP device metadata not found.");
462+
const char* ov_device_key = "ov_device";
463+
const char* ov_meta_device_key = "ov_meta_device";
464+
465+
// Create a unique list of ov_devices that were passed in.
466+
std::unordered_set<std::string_view> unique_ov_devices;
467+
std::vector<std::string_view> ordered_unique_ov_devices;
468+
for (size_t i = 0; i < num_devices; ++i) {
469+
const auto& device_meta_data = ep_metadata[i];
470+
auto ov_device_it = device_meta_data->Entries().find(ov_device_key);
471+
if (ov_device_it == device_meta_data->Entries().end()) {
472+
return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, "OpenVINO EP device metadata not found.");
473+
}
474+
auto &ov_device = ov_device_it->second;
475+
476+
// Add to ordered_unique only if not already present
477+
if (unique_ov_devices.insert(ov_device).second) {
478+
ordered_unique_ov_devices.push_back(ov_device);
479+
}
480+
}
481+
482+
std::string ov_meta_device_type = "NONE";
483+
{
484+
auto ov_meta_device_it = ep_metadata[0]->Entries().find(ov_meta_device_key);
485+
if (ov_meta_device_it != ep_metadata[0]->Entries().end()) {
486+
ov_meta_device_type = ov_meta_device_it->second;
487+
}
466488
}
467489

468-
std::string metadata_device_type = it->second;
490+
bool is_meta_device_factory = (ov_meta_device_type != "NONE");
469491

470-
// If user didn't specify device_type, use the one from metadata
471-
if (provider_options.find("device_type") == provider_options.end()) {
472-
provider_options["device_type"] = metadata_device_type;
492+
if (ordered_unique_ov_devices.size() > 1 && !is_meta_device_factory) {
493+
LOGS_DEFAULT(WARNING) << "[OpenVINO EP] Multiple devices were specified that are not OpenVINO meta devices. Using first ov_device only: " << ordered_unique_ov_devices.at(0);
494+
ordered_unique_ov_devices.resize(1); // Use only the first device if not a meta device factory
473495
}
474496

497+
std::string ov_device_string;
498+
if (is_meta_device_factory) {
499+
// Build up a meta device string based on the devices that are passed in. E.g. AUTO:NPU,GPU.0,CPU
500+
ov_device_string = ov_meta_device_type;
501+
ov_device_string += ":";
502+
}
503+
504+
bool prepend_comma = false;
505+
for (const auto& ov_device : ordered_unique_ov_devices) {
506+
if (prepend_comma) {
507+
ov_device_string += ",";
508+
}
509+
ov_device_string += ov_device;
510+
prepend_comma = true;
511+
}
512+
513+
provider_options["device_type"] = ov_device_string;
514+
475515
// Parse provider info with the device type
476516
ProviderInfo pi;
477517
const auto& config_options = session_options.GetConfigOptions();

onnxruntime/core/providers/openvino/ov_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice*
124124
ort_api.CreateKeyValuePairs(&ep_metadata);
125125
ort_api.AddKeyValuePair(ep_metadata, ov_device_key_, matched_device->c_str());
126126

127+
if (IsMetaDeviceFactory()) {
128+
ort_api.AddKeyValuePair(ep_metadata, ov_meta_device_key_, device_type_.c_str());
129+
}
130+
127131
// Create EP device
128132
auto* status = ort_api.GetEpApi()->CreateEpDevice(this, &device, ep_metadata, ep_options,
129133
&ep_devices[num_ep_devices++]);

onnxruntime/core/providers/openvino/ov_factory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class OpenVINOEpPluginFactory : public OrtEpFactory, public ApiPtrs {
104104
static constexpr const char* vendor_ = "Intel";
105105
static constexpr uint32_t vendor_id_{0x8086}; // Intel's PCI vendor ID
106106
static constexpr const char* ov_device_key_ = "ov_device";
107+
static constexpr const char* ov_meta_device_key_ = "ov_meta_device";
107108
static constexpr const char* provider_name_ = "OpenVINOExecutionProvider";
108109

109110
private:

0 commit comments

Comments
 (0)