Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scripts/templates/ze_loader_internal.h.mako
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ namespace loader
ZEL_DRIVER_TYPE_INTEGRATED_GPU = 2, ///< The driver has Integrated GPUs only
ZEL_DRIVER_TYPE_MIXED = 3, ///< The driver has Heterogenous driver types not limited to GPU or NPU.
ZEL_DRIVER_TYPE_OTHER = 4, ///< The driver has No GPU Devices and has other device types only
ZEL_DRIVER_TYPE_NPU = 5, ///< The driver has NPU devices only
ZEL_DRIVER_TYPE_FORCE_UINT32 = 0x7fffffff

} zel_driver_type_t;
Expand Down Expand Up @@ -114,6 +115,7 @@ namespace loader
ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly);
void add_loader_version();
bool driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly);
void driverOrdering(driver_vector_t *drivers);
~context_t();
bool intercept_enabled = false;
bool debugTraceEnabled = false;
Expand Down
183 changes: 180 additions & 3 deletions source/loader/ze_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* SPDX-License-Identifier: MIT
*
*/
#include "ze_loader_internal.h"
#include "ze_loader_utils.h"

#include "driver_discovery.h"
#include <iostream>
Expand Down Expand Up @@ -72,6 +72,179 @@ namespace loader
return a.driverType < b.driverType;
}

// Reorders `*drivers` in place according to the ZEL_DRIVERS_ORDER environment variable.
//
// The variable is parsed by parseDriverOrder() into a list of specs which are applied
// left to right:
//   - BY_GLOBAL_INDEX:   the driver at that position in the incoming vector.
//   - BY_TYPE:           every not-yet-placed driver of that type, in incoming order.
//   - BY_TYPE_AND_INDEX: the N-th driver of that type, counted in incoming order.
// A driver is placed at most once; specs that refer to an out-of-range index or to an
// already placed driver are ignored. Drivers not selected by any spec are appended
// afterwards in their incoming order, so no driver is ever dropped.
//
// For grouping purposes ZEL_DRIVER_TYPE_MIXED drivers are counted with the GPU group and
// ZEL_DRIVER_TYPE_OTHER drivers with the NPU group.
//
// If the variable is unset/empty, or yields no valid spec, the vector is left untouched.
void context_t::driverOrdering(driver_vector_t *drivers) {
    std::string orderStr = getenv_string("ZEL_DRIVERS_ORDER");
    if (orderStr.empty()) {
        return; // No ordering specified
    }

    std::vector<DriverOrderSpec> specs = parseDriverOrder(orderStr);

    if (specs.empty()) {
        if (debugTraceEnabled) {
            std::string message = "driverOrdering: ZEL_DRIVERS_ORDER parsing failed or empty: " + orderStr;
            debug_trace_message(message, "");
        }
        return;
    }

    if (debugTraceEnabled) {
        std::string message = "driverOrdering:ZEL_DRIVERS_ORDER parsing successful: " + orderStr + ", specs count: " + std::to_string(specs.size());
        debug_trace_message(message, "");
    }

    // Read-only view of the incoming order; all indices below refer to this order.
    const driver_vector_t &discovered = *drivers;
    const uint32_t driverCount = static_cast<uint32_t>(discovered.size());

    // Per-group lists of global indices. Only indices are stored; the driver objects
    // themselves are copied exactly once, when they are placed into the result.
    std::vector<uint32_t> discreteGPUIndices;
    std::vector<uint32_t> integratedGPUIndices;
    std::vector<uint32_t> gpuIndices;
    std::vector<uint32_t> npuIndices;

    for (uint32_t i = 0; i < driverCount; ++i) {
        switch (discovered[i].driverType) {
        case ZEL_DRIVER_TYPE_DISCRETE_GPU:
            discreteGPUIndices.push_back(i);
            break;
        case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
            integratedGPUIndices.push_back(i);
            break;
        case ZEL_DRIVER_TYPE_GPU:
        case ZEL_DRIVER_TYPE_MIXED: // Mixed drivers are grouped with GPU drivers
            gpuIndices.push_back(i);
            break;
        case ZEL_DRIVER_TYPE_NPU:
        case ZEL_DRIVER_TYPE_OTHER: // "Other" drivers are grouped with NPU drivers
            npuIndices.push_back(i);
            break;
        default:
            // Unknown types belong to no group; they are still kept via the
            // "remaining drivers" pass at the end.
            break;
        }
    }

    // Maps a requested driver type to its index list, or nullptr if the type has no group.
    auto indicesForType = [&](zel_driver_type_t requested) -> const std::vector<uint32_t> * {
        switch (requested) {
        case ZEL_DRIVER_TYPE_DISCRETE_GPU:
            return &discreteGPUIndices;
        case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
            return &integratedGPUIndices;
        case ZEL_DRIVER_TYPE_GPU:
            return &gpuIndices;
        case ZEL_DRIVER_TYPE_NPU:
        case ZEL_DRIVER_TYPE_OTHER:
            return &npuIndices;
        default:
            return nullptr;
        }
    };

    driver_vector_t orderedDrivers;
    orderedDrivers.reserve(discovered.size());

    // placed[i] != 0 once discovered[i] has been appended to orderedDrivers.
    std::vector<char> placed(driverCount, 0);
    auto placeOnce = [&](uint32_t globalIdx) {
        if (placed[globalIdx]) {
            return;
        }
        placed[globalIdx] = 1;
        orderedDrivers.push_back(discovered[globalIdx]);
    };

    // Apply ordering specifications in the order the user wrote them
    for (const auto &spec : specs) {
        switch (spec.type) {
        case DriverOrderSpecType::BY_GLOBAL_INDEX:
            if (spec.globalIndex < driverCount) {
                placeOnce(spec.globalIndex);
            }
            break;

        case DriverOrderSpecType::BY_TYPE: {
            const std::vector<uint32_t> *typeIndices = indicesForType(spec.driverType);
            if (typeIndices) {
                for (uint32_t globalIdx : *typeIndices) {
                    placeOnce(globalIdx);
                }
            }
            break;
        }

        case DriverOrderSpecType::BY_TYPE_AND_INDEX: {
            const std::vector<uint32_t> *typeIndices = indicesForType(spec.driverType);
            if (typeIndices && spec.typeIndex < typeIndices->size()) {
                placeOnce((*typeIndices)[spec.typeIndex]);
            }
            break;
        }
        }
    }

    // Add remaining drivers in their original order
    for (uint32_t i = 0; i < driverCount; ++i) {
        placeOnce(i);
    }

    // Hand the reordered list to the caller without another element-wise copy.
    // `discovered` must not be used past this point.
    drivers->swap(orderedDrivers);

    if (debugTraceEnabled) {
        std::string message = "driverOrdering: Drivers after ZEL_DRIVERS_ORDER:";
        for (uint32_t i = 0; i < drivers->size(); ++i) {
            message += "\n[" + std::to_string(i) + "] Driver Type: " + std::to_string((*drivers)[i].driverType) + " Driver Name: " + (*drivers)[i].name;
        }
        debug_trace_message(message, "");
    }
}

bool context_t::driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly) {
ze_init_driver_type_desc_t permissiveDesc = {};
permissiveDesc.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
Expand Down Expand Up @@ -246,6 +419,10 @@ namespace loader
}
debug_trace_message(message, "");
}

// Apply driver ordering based on ZEL_DRIVERS_ORDER environment variable
driverOrdering(drivers);

return true;
}

Expand Down Expand Up @@ -577,7 +754,7 @@ namespace loader
GET_FUNCTION_PTR(validationLayer, "zelLoaderGetVersion"));
zel_component_version_t compVersion;
if(getVersion && ZE_RESULT_SUCCESS == getVersion(&compVersion))
{
{
compVersions.push_back(compVersion);
}
} else if (debugTraceEnabled) {
Expand All @@ -602,7 +779,7 @@ namespace loader
GET_FUNCTION_PTR(tracingLayer, "zelLoaderGetVersion"));
zel_component_version_t compVersion;
if(getVersion && ZE_RESULT_SUCCESS == getVersion(&compVersion))
{
{
compVersions.push_back(compVersion);
}
} else if (debugTraceEnabled) {
Expand Down
2 changes: 2 additions & 0 deletions source/loader/ze_loader_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace loader
ZEL_DRIVER_TYPE_INTEGRATED_GPU = 2, ///< The driver has Integrated GPUs only
ZEL_DRIVER_TYPE_MIXED = 3, ///< The driver has Heterogenous driver types not limited to GPU or NPU.
ZEL_DRIVER_TYPE_OTHER = 4, ///< The driver has No GPU Devices and has other device types only
ZEL_DRIVER_TYPE_NPU = 5, ///< The driver has NPU devices only
ZEL_DRIVER_TYPE_FORCE_UINT32 = 0x7fffffff

} zel_driver_type_t;
Expand Down Expand Up @@ -150,6 +151,7 @@ namespace loader
ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly);
void add_loader_version();
bool driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly);
void driverOrdering(driver_vector_t *drivers);
~context_t();
bool intercept_enabled = false;
bool debugTraceEnabled = false;
Expand Down
108 changes: 108 additions & 0 deletions source/loader/ze_loader_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#pragma once
#include "ze_loader_internal.h"
#include <string>
#include <vector>
#include <map>
#include <set>
#include <sstream>
#include <cstdlib>
#include <algorithm>


namespace loader
{
// Translates a driver-type name into the matching zel_driver_type_t value.
//
// Recognized names (exact, case-sensitive match):
//   "DISCRETE_GPU_ONLY", "GPU", "INTEGRATED_GPU_ONLY", "NPU".
// Any other input yields ZEL_DRIVER_TYPE_FORCE_UINT32, which callers treat as "invalid".
inline zel_driver_type_t stringToDriverType(const std::string& typeStr) {
    struct NamedDriverType {
        const char* label;
        zel_driver_type_t value;
    };
    static const NamedDriverType knownTypes[] = {
        {"DISCRETE_GPU_ONLY", ZEL_DRIVER_TYPE_DISCRETE_GPU},
        {"GPU", ZEL_DRIVER_TYPE_GPU},
        {"INTEGRATED_GPU_ONLY", ZEL_DRIVER_TYPE_INTEGRATED_GPU},
        {"NPU", ZEL_DRIVER_TYPE_NPU},
    };

    for (const NamedDriverType& candidate : knownTypes) {
        if (typeStr == candidate.label) {
            return candidate.value;
        }
    }
    return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid
}

// Returns a copy of `str` without leading and trailing whitespace
// (space, tab, newline, carriage return, form feed, vertical tab).
// Interior whitespace is preserved; an all-whitespace or empty input yields "".
inline std::string trim(const std::string& str) {
    static const char* const kBlankChars = " \t\n\r\f\v";

    const std::string::size_type firstKept = str.find_first_not_of(kBlankChars);
    if (firstKept == std::string::npos) {
        // Nothing but whitespace (or empty input).
        return std::string();
    }

    // A non-blank character exists, so this search cannot fail.
    const std::string::size_type lastKept = str.find_last_not_of(kBlankChars);
    const std::string::size_type keptLength = lastKept - firstKept + 1;
    return std::string(str, firstKept, keptLength);
}

// Form of a single comma-separated entry in ZEL_DRIVERS_ORDER:
//   BY_GLOBAL_INDEX   - a bare number, e.g. "1"
//   BY_TYPE           - a bare driver type name, e.g. "NPU"
//   BY_TYPE_AND_INDEX - "<driver_type>:<driver_index>", e.g. "GPU:0"
enum DriverOrderSpecType { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX };

// Structure to hold one parsed ordering instruction.
//
// `type` selects which of the other members are meaningful:
//   BY_GLOBAL_INDEX   -> globalIndex
//   BY_TYPE           -> driverType
//   BY_TYPE_AND_INDEX -> driverType and typeIndex
// Every member has a default so that a default-constructed spec never holds an
// indeterminate value.
struct DriverOrderSpec {
    DriverOrderSpecType type = DriverOrderSpecType::BY_GLOBAL_INDEX;
    uint32_t globalIndex = 0;                                    // position in the overall driver list
    zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32; // FORCE_UINT32 is used as "no/invalid type"
    uint32_t typeIndex = 0;                                      // position among drivers of `driverType`
};

// Strictly parses `text` as an unsigned decimal index.
//
// Succeeds only if `text` is non-empty, consists solely of the digits 0-9 and the value
// fits in uint32_t. Signs, whitespace, hex prefixes and trailing characters are rejected.
// On success the value is stored in `value` and true is returned; on failure `value`
// is left unchanged and false is returned.
inline bool parseDriverIndex(const std::string& text, uint32_t& value) {
    if (text.empty()) {
        return false;
    }

    const unsigned long long maxIndex = 0xFFFFFFFFull;
    unsigned long long parsed = 0;
    for (char c : text) {
        if (c < '0' || c > '9') {
            return false;
        }
        // parsed <= maxIndex here, so this cannot overflow 64 bits.
        parsed = parsed * 10 + static_cast<unsigned long long>(c - '0');
        if (parsed > maxIndex) {
            return false;
        }
    }

    value = static_cast<uint32_t>(parsed);
    return true;
}

// Parse ZEL_DRIVERS_ORDER environment variable.
//
// `orderStr` is a comma-separated list; whitespace around entries and around the ':'
// separator is ignored. Each entry is one of:
//   <driver_index>                 -> BY_GLOBAL_INDEX
//   <driver_type>                  -> BY_TYPE
//   <driver_type>:<driver_index>   -> BY_TYPE_AND_INDEX
// where <driver_type> is a name accepted by stringToDriverType() and <driver_index> is
// an unsigned decimal number that fits in uint32_t.
//
// Empty entries and entries that do not match one of these forms (unknown type name,
// signed / non-numeric / out-of-range index, trailing garbage after the index) are
// skipped; the remaining entries are returned in the order they appear. An empty result
// means nothing usable was found.
inline std::vector<DriverOrderSpec> parseDriverOrder(const std::string& orderStr) {
    std::vector<DriverOrderSpec> specs;

    // Split by comma
    std::stringstream ss(orderStr);
    std::string token;

    while (std::getline(ss, token, ',')) {
        token = trim(token);
        if (token.empty()) continue;

        DriverOrderSpec spec;

        const size_t colonPos = token.find(':');
        if (colonPos != std::string::npos) {
            // Format: <driver_type>:<driver_index>
            spec.driverType = stringToDriverType(trim(token.substr(0, colonPos)));
            if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) {
                continue; // Invalid driver type, skip
            }
            if (!parseDriverIndex(trim(token.substr(colonPos + 1)), spec.typeIndex)) {
                continue; // Invalid index, skip
            }
            spec.type = DriverOrderSpecType::BY_TYPE_AND_INDEX;
        } else if (parseDriverIndex(token, spec.globalIndex)) {
            // Pure number: global index
            spec.type = DriverOrderSpecType::BY_GLOBAL_INDEX;
        } else {
            // Not a number, try as driver type
            spec.driverType = stringToDriverType(token);
            if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) {
                continue; // Neither an index nor a known type, skip
            }
            spec.type = DriverOrderSpecType::BY_TYPE;
        }

        specs.push_back(spec);
    }

    return specs;
}
} // namespace loader
Loading
Loading