Skip to content

Commit d76cae1

Browse files
committed
Support for ZEL_DRIVERS_ORDER to order based off user input
Signed-off-by: Neil R. Spruit <[email protected]>
1 parent f874a09 commit d76cae1

File tree

6 files changed

+2099
-5
lines changed

6 files changed

+2099
-5
lines changed

source/loader/ze_loader.cpp

Lines changed: 272 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99

1010
#include "driver_discovery.h"
1111
#include <iostream>
12+
#include <string>
13+
#include <vector>
14+
#include <map>
15+
#include <set>
16+
#include <sstream>
17+
#include <cstdlib>
18+
#include <algorithm>
1219

1320
#ifdef __linux__
1421
#include <unistd.h>
@@ -72,6 +79,265 @@ namespace loader
7279
return a.driverType < b.driverType;
7380
}
7481

82+
// Helper function to map driver type string to enum
83+
zel_driver_type_t stringToDriverType(const std::string& typeStr) {
84+
if (typeStr == "DISCRETE_GPU_ONLY") {
85+
return ZEL_DRIVER_TYPE_DISCRETE_GPU;
86+
} else if (typeStr == "GPU") {
87+
return ZEL_DRIVER_TYPE_GPU;
88+
} else if (typeStr == "INTEGRATED_GPU_ONLY") {
89+
return ZEL_DRIVER_TYPE_INTEGRATED_GPU;
90+
} else if (typeStr == "NPU") {
91+
return ZEL_DRIVER_TYPE_NPU;
92+
}
93+
return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid
94+
}
95+
96+
// Helper function to trim whitespace
97+
std::string trim(const std::string& str) {
98+
const std::string whitespace = " \t\n\r\f\v";
99+
size_t start = str.find_first_not_of(whitespace);
100+
if (start == std::string::npos) return "";
101+
size_t end = str.find_last_not_of(whitespace);
102+
return str.substr(start, end - start + 1);
103+
}
104+
105+
// Structure to hold parsed ordering instructions
106+
struct DriverOrderSpec {
107+
enum Type { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX } type;
108+
uint32_t globalIndex = 0;
109+
zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32;
110+
uint32_t typeIndex = 0;
111+
};
112+
113+
// Parse ZEL_DRIVERS_ORDER environment variable
114+
std::vector<DriverOrderSpec> parseDriverOrder(const std::string& orderStr) {
115+
std::vector<DriverOrderSpec> specs;
116+
117+
// Split by comma
118+
std::vector<std::string> tokens;
119+
std::stringstream ss(orderStr);
120+
std::string token;
121+
122+
while (std::getline(ss, token, ',')) {
123+
token = trim(token);
124+
if (token.empty()) continue;
125+
126+
DriverOrderSpec spec;
127+
128+
// Check if it contains a colon (type:index format)
129+
size_t colonPos = token.find(':');
130+
if (colonPos != std::string::npos) {
131+
// Format: <driver_type>:<driver_index>
132+
std::string typeStr = trim(token.substr(0, colonPos));
133+
std::string indexStr = trim(token.substr(colonPos + 1));
134+
135+
spec.driverType = stringToDriverType(typeStr);
136+
if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) {
137+
continue; // Invalid driver type, skip
138+
}
139+
140+
try {
141+
spec.typeIndex = std::stoul(indexStr);
142+
spec.type = DriverOrderSpec::BY_TYPE_AND_INDEX;
143+
specs.push_back(spec);
144+
} catch (const std::exception&) {
145+
// Invalid index, skip
146+
continue;
147+
}
148+
} else {
149+
// Check if it's a pure number (global index) or driver type
150+
try {
151+
spec.globalIndex = std::stoul(token);
152+
spec.type = DriverOrderSpec::BY_GLOBAL_INDEX;
153+
specs.push_back(spec);
154+
} catch (const std::exception&) {
155+
// Not a number, try as driver type
156+
spec.driverType = stringToDriverType(token);
157+
if (spec.driverType != ZEL_DRIVER_TYPE_FORCE_UINT32) {
158+
spec.type = DriverOrderSpec::BY_TYPE;
159+
specs.push_back(spec);
160+
}
161+
}
162+
}
163+
}
164+
165+
return specs;
166+
}
167+
168+
void context_t::driverOrdering(driver_vector_t *drivers) {
169+
std::string orderStr = getenv_string("ZEL_DRIVERS_ORDER");
170+
if (orderStr.empty()) {
171+
return; // No ordering specified
172+
}
173+
174+
std::vector<DriverOrderSpec> specs = parseDriverOrder(orderStr);
175+
176+
if (specs.empty()) {
177+
if (debugTraceEnabled) {
178+
std::string message = "driverOrdering: ZEL_DRIVERS_ORDER parsing failed or empty: " + orderStr;
179+
debug_trace_message(message, "");
180+
}
181+
return;
182+
}
183+
184+
if (debugTraceEnabled) {
185+
std::string message = "driverOrdering:ZEL_DRIVERS_ORDER parsing successful: " + orderStr + ", specs count: " + std::to_string(specs.size());
186+
debug_trace_message(message, "");
187+
}
188+
189+
// Create a copy of the original driver vector for reference
190+
driver_vector_t originalDrivers = *drivers;
191+
192+
driver_vector_t discreteGPUDrivers;
193+
driver_vector_t integratedGPUDrivers;
194+
driver_vector_t npuDrivers;
195+
driver_vector_t gpuDrivers;
196+
197+
std::vector<uint32_t> discreteGPUIndices;
198+
std::vector<uint32_t> integratedGPUIndices;
199+
std::vector<uint32_t> npuIndices;
200+
std::vector<uint32_t> gpuIndices;
201+
202+
// Group drivers by type and track their original indices
203+
for (uint32_t i = 0; i < originalDrivers.size(); ++i) {
204+
const auto& driver = originalDrivers[i];
205+
switch (driver.driverType) {
206+
case ZEL_DRIVER_TYPE_DISCRETE_GPU:
207+
discreteGPUDrivers.push_back(driver);
208+
discreteGPUIndices.push_back(i);
209+
break;
210+
case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
211+
integratedGPUDrivers.push_back(driver);
212+
integratedGPUIndices.push_back(i);
213+
break;
214+
case ZEL_DRIVER_TYPE_GPU:
215+
gpuDrivers.push_back(driver);
216+
gpuIndices.push_back(i);
217+
break;
218+
case ZEL_DRIVER_TYPE_NPU:
219+
npuDrivers.push_back(driver);
220+
npuIndices.push_back(i);
221+
break;
222+
case ZEL_DRIVER_TYPE_OTHER:
223+
npuDrivers.push_back(driver);
224+
npuIndices.push_back(i);
225+
break;
226+
case ZEL_DRIVER_TYPE_MIXED:
227+
// Mixed drivers go to gpuDrivers
228+
gpuDrivers.push_back(driver);
229+
gpuIndices.push_back(i);
230+
break;
231+
default:
232+
break;
233+
}
234+
}
235+
236+
// Create new ordered driver vector
237+
driver_vector_t orderedDrivers;
238+
std::set<uint32_t> usedGlobalIndices;
239+
std::set<std::pair<zel_driver_type_t, uint32_t>> usedTypeIndices;
240+
241+
// Apply ordering specifications
242+
for (const auto& spec : specs) {
243+
switch (spec.type) {
244+
case DriverOrderSpec::BY_GLOBAL_INDEX:
245+
if (spec.globalIndex < originalDrivers.size() &&
246+
usedGlobalIndices.find(spec.globalIndex) == usedGlobalIndices.end()) {
247+
orderedDrivers.push_back(originalDrivers[spec.globalIndex]);
248+
usedGlobalIndices.insert(spec.globalIndex);
249+
}
250+
break;
251+
252+
case DriverOrderSpec::BY_TYPE:
253+
// Add all drivers of this type that haven't been used
254+
{
255+
std::vector<uint32_t>* typeIndices = nullptr;
256+
switch (spec.driverType) {
257+
case ZEL_DRIVER_TYPE_DISCRETE_GPU:
258+
typeIndices = &discreteGPUIndices;
259+
break;
260+
case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
261+
typeIndices = &integratedGPUIndices;
262+
break;
263+
case ZEL_DRIVER_TYPE_GPU:
264+
typeIndices = &gpuIndices;
265+
break;
266+
case ZEL_DRIVER_TYPE_NPU:
267+
case ZEL_DRIVER_TYPE_OTHER:
268+
typeIndices = &npuIndices;
269+
break;
270+
default:
271+
break;
272+
}
273+
274+
if (typeIndices) {
275+
for (uint32_t globalIdx : *typeIndices) {
276+
if (usedGlobalIndices.find(globalIdx) == usedGlobalIndices.end()) {
277+
orderedDrivers.push_back(originalDrivers[globalIdx]);
278+
usedGlobalIndices.insert(globalIdx);
279+
}
280+
}
281+
}
282+
}
283+
break;
284+
285+
case DriverOrderSpec::BY_TYPE_AND_INDEX:
286+
{
287+
std::vector<uint32_t>* typeIndices = nullptr;
288+
switch (spec.driverType) {
289+
case ZEL_DRIVER_TYPE_DISCRETE_GPU:
290+
typeIndices = &discreteGPUIndices;
291+
break;
292+
case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
293+
typeIndices = &integratedGPUIndices;
294+
break;
295+
case ZEL_DRIVER_TYPE_GPU:
296+
typeIndices = &gpuIndices;
297+
break;
298+
case ZEL_DRIVER_TYPE_NPU:
299+
case ZEL_DRIVER_TYPE_OTHER:
300+
typeIndices = &npuIndices;
301+
break;
302+
default:
303+
break;
304+
}
305+
306+
if (typeIndices && spec.typeIndex < typeIndices->size()) {
307+
auto typeIndexPair = std::make_pair(spec.driverType, spec.typeIndex);
308+
if (usedTypeIndices.find(typeIndexPair) == usedTypeIndices.end()) {
309+
uint32_t globalIdx = (*typeIndices)[spec.typeIndex];
310+
if (usedGlobalIndices.find(globalIdx) == usedGlobalIndices.end()) {
311+
orderedDrivers.push_back(originalDrivers[globalIdx]);
312+
usedGlobalIndices.insert(globalIdx);
313+
usedTypeIndices.insert(typeIndexPair);
314+
}
315+
}
316+
}
317+
}
318+
break;
319+
}
320+
}
321+
322+
// Add remaining drivers in their original order
323+
for (uint32_t i = 0; i < originalDrivers.size(); ++i) {
324+
if (usedGlobalIndices.find(i) == usedGlobalIndices.end()) {
325+
orderedDrivers.push_back(originalDrivers[i]);
326+
}
327+
}
328+
329+
// Replace the original driver vector with the ordered one
330+
*drivers = orderedDrivers;
331+
332+
if (debugTraceEnabled) {
333+
std::string message = "driverOrdering: Drivers after ZEL_DRIVERS_ORDER:";
334+
for (uint32_t i = 0; i < drivers->size(); ++i) {
335+
message += "\n[" + std::to_string(i) + "] Driver Type: " + std::to_string((*drivers)[i].driverType) + " Driver Name: " + (*drivers)[i].name;
336+
}
337+
debug_trace_message(message, "");
338+
}
339+
}
340+
75341
bool context_t::driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly) {
76342
ze_init_driver_type_desc_t permissiveDesc = {};
77343
permissiveDesc.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
@@ -246,6 +512,10 @@ namespace loader
246512
}
247513
debug_trace_message(message, "");
248514
}
515+
516+
// Apply driver ordering based on ZEL_DRIVERS_ORDER environment variable
517+
driverOrdering(drivers);
518+
249519
return true;
250520
}
251521

@@ -577,7 +847,7 @@ namespace loader
577847
GET_FUNCTION_PTR(validationLayer, "zelLoaderGetVersion"));
578848
zel_component_version_t compVersion;
579849
if(getVersion && ZE_RESULT_SUCCESS == getVersion(&compVersion))
580-
{
850+
{
581851
compVersions.push_back(compVersion);
582852
}
583853
} else if (debugTraceEnabled) {
@@ -602,7 +872,7 @@ namespace loader
602872
GET_FUNCTION_PTR(tracingLayer, "zelLoaderGetVersion"));
603873
zel_component_version_t compVersion;
604874
if(getVersion && ZE_RESULT_SUCCESS == getVersion(&compVersion))
605-
{
875+
{
606876
compVersions.push_back(compVersion);
607877
}
608878
} else if (debugTraceEnabled) {

source/loader/ze_loader_internal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace loader
3838
ZEL_DRIVER_TYPE_INTEGRATED_GPU = 2, ///< The driver has Integrated GPUs only
3939
ZEL_DRIVER_TYPE_MIXED = 3, ///< The driver has Heterogenous driver types not limited to GPU or NPU.
4040
ZEL_DRIVER_TYPE_OTHER = 4, ///< The driver has No GPU Devices and has other device types only
41+
ZEL_DRIVER_TYPE_NPU = 5, ///< The driver has NPU devices only
4142
ZEL_DRIVER_TYPE_FORCE_UINT32 = 0x7fffffff
4243

4344
} zel_driver_type_t;
@@ -150,6 +151,7 @@ namespace loader
150151
ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly);
151152
void add_loader_version();
152153
bool driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly);
154+
void driverOrdering(driver_vector_t *drivers);
153155
~context_t();
154156
bool intercept_enabled = false;
155157
bool debugTraceEnabled = false;

0 commit comments

Comments
 (0)