9
9
10
10
#include " driver_discovery.h"
11
11
#include < iostream>
12
+ #include < string>
13
+ #include < vector>
14
+ #include < map>
15
+ #include < set>
16
+ #include < sstream>
17
+ #include < cstdlib>
18
+ #include < algorithm>
12
19
13
20
#ifdef __linux__
14
21
#include < unistd.h>
@@ -72,6 +79,265 @@ namespace loader
72
79
return a.driverType < b.driverType ;
73
80
}
74
81
82
+ // Helper function to map driver type string to enum
83
+ zel_driver_type_t stringToDriverType (const std::string& typeStr) {
84
+ if (typeStr == " DISCRETE_GPU_ONLY" ) {
85
+ return ZEL_DRIVER_TYPE_DISCRETE_GPU;
86
+ } else if (typeStr == " GPU" ) {
87
+ return ZEL_DRIVER_TYPE_GPU;
88
+ } else if (typeStr == " INTEGRATED_GPU_ONLY" ) {
89
+ return ZEL_DRIVER_TYPE_INTEGRATED_GPU;
90
+ } else if (typeStr == " NPU" ) {
91
+ return ZEL_DRIVER_TYPE_NPU;
92
+ }
93
+ return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid
94
+ }
95
+
96
+ // Helper function to trim whitespace
97
+ std::string trim (const std::string& str) {
98
+ const std::string whitespace = " \t\n\r\f\v " ;
99
+ size_t start = str.find_first_not_of (whitespace);
100
+ if (start == std::string::npos) return " " ;
101
+ size_t end = str.find_last_not_of (whitespace);
102
+ return str.substr (start, end - start + 1 );
103
+ }
104
+
105
+ // Structure to hold parsed ordering instructions
106
+ struct DriverOrderSpec {
107
+ enum Type { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX } type;
108
+ uint32_t globalIndex = 0 ;
109
+ zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32;
110
+ uint32_t typeIndex = 0 ;
111
+ };
112
+
113
+ // Parse ZEL_DRIVERS_ORDER environment variable
114
+ std::vector<DriverOrderSpec> parseDriverOrder (const std::string& orderStr) {
115
+ std::vector<DriverOrderSpec> specs;
116
+
117
+ // Split by comma
118
+ std::vector<std::string> tokens;
119
+ std::stringstream ss (orderStr);
120
+ std::string token;
121
+
122
+ while (std::getline (ss, token, ' ,' )) {
123
+ token = trim (token);
124
+ if (token.empty ()) continue ;
125
+
126
+ DriverOrderSpec spec;
127
+
128
+ // Check if it contains a colon (type:index format)
129
+ size_t colonPos = token.find (' :' );
130
+ if (colonPos != std::string::npos) {
131
+ // Format: <driver_type>:<driver_index>
132
+ std::string typeStr = trim (token.substr (0 , colonPos));
133
+ std::string indexStr = trim (token.substr (colonPos + 1 ));
134
+
135
+ spec.driverType = stringToDriverType (typeStr);
136
+ if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) {
137
+ continue ; // Invalid driver type, skip
138
+ }
139
+
140
+ try {
141
+ spec.typeIndex = std::stoul (indexStr);
142
+ spec.type = DriverOrderSpec::BY_TYPE_AND_INDEX;
143
+ specs.push_back (spec);
144
+ } catch (const std::exception&) {
145
+ // Invalid index, skip
146
+ continue ;
147
+ }
148
+ } else {
149
+ // Check if it's a pure number (global index) or driver type
150
+ try {
151
+ spec.globalIndex = std::stoul (token);
152
+ spec.type = DriverOrderSpec::BY_GLOBAL_INDEX;
153
+ specs.push_back (spec);
154
+ } catch (const std::exception&) {
155
+ // Not a number, try as driver type
156
+ spec.driverType = stringToDriverType (token);
157
+ if (spec.driverType != ZEL_DRIVER_TYPE_FORCE_UINT32) {
158
+ spec.type = DriverOrderSpec::BY_TYPE;
159
+ specs.push_back (spec);
160
+ }
161
+ }
162
+ }
163
+ }
164
+
165
+ return specs;
166
+ }
167
+
168
+ void context_t::driverOrdering (driver_vector_t *drivers) {
169
+ std::string orderStr = getenv_string (" ZEL_DRIVERS_ORDER" );
170
+ if (orderStr.empty ()) {
171
+ return ; // No ordering specified
172
+ }
173
+
174
+ std::vector<DriverOrderSpec> specs = parseDriverOrder (orderStr);
175
+
176
+ if (specs.empty ()) {
177
+ if (debugTraceEnabled) {
178
+ std::string message = " driverOrdering: ZEL_DRIVERS_ORDER parsing failed or empty: " + orderStr;
179
+ debug_trace_message (message, " " );
180
+ }
181
+ return ;
182
+ }
183
+
184
+ if (debugTraceEnabled) {
185
+ std::string message = " driverOrdering:ZEL_DRIVERS_ORDER parsing successful: " + orderStr + " , specs count: " + std::to_string (specs.size ());
186
+ debug_trace_message (message, " " );
187
+ }
188
+
189
+ // Create a copy of the original driver vector for reference
190
+ driver_vector_t originalDrivers = *drivers;
191
+
192
+ driver_vector_t discreteGPUDrivers;
193
+ driver_vector_t integratedGPUDrivers;
194
+ driver_vector_t npuDrivers;
195
+ driver_vector_t gpuDrivers;
196
+
197
+ std::vector<uint32_t > discreteGPUIndices;
198
+ std::vector<uint32_t > integratedGPUIndices;
199
+ std::vector<uint32_t > npuIndices;
200
+ std::vector<uint32_t > gpuIndices;
201
+
202
+ // Group drivers by type and track their original indices
203
+ for (uint32_t i = 0 ; i < originalDrivers.size (); ++i) {
204
+ const auto & driver = originalDrivers[i];
205
+ switch (driver.driverType ) {
206
+ case ZEL_DRIVER_TYPE_DISCRETE_GPU:
207
+ discreteGPUDrivers.push_back (driver);
208
+ discreteGPUIndices.push_back (i);
209
+ break ;
210
+ case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
211
+ integratedGPUDrivers.push_back (driver);
212
+ integratedGPUIndices.push_back (i);
213
+ break ;
214
+ case ZEL_DRIVER_TYPE_GPU:
215
+ gpuDrivers.push_back (driver);
216
+ gpuIndices.push_back (i);
217
+ break ;
218
+ case ZEL_DRIVER_TYPE_NPU:
219
+ npuDrivers.push_back (driver);
220
+ npuIndices.push_back (i);
221
+ break ;
222
+ case ZEL_DRIVER_TYPE_OTHER:
223
+ npuDrivers.push_back (driver);
224
+ npuIndices.push_back (i);
225
+ break ;
226
+ case ZEL_DRIVER_TYPE_MIXED:
227
+ // Mixed drivers go to gpuDrivers
228
+ gpuDrivers.push_back (driver);
229
+ gpuIndices.push_back (i);
230
+ break ;
231
+ default :
232
+ break ;
233
+ }
234
+ }
235
+
236
+ // Create new ordered driver vector
237
+ driver_vector_t orderedDrivers;
238
+ std::set<uint32_t > usedGlobalIndices;
239
+ std::set<std::pair<zel_driver_type_t , uint32_t >> usedTypeIndices;
240
+
241
+ // Apply ordering specifications
242
+ for (const auto & spec : specs) {
243
+ switch (spec.type ) {
244
+ case DriverOrderSpec::BY_GLOBAL_INDEX:
245
+ if (spec.globalIndex < originalDrivers.size () &&
246
+ usedGlobalIndices.find (spec.globalIndex ) == usedGlobalIndices.end ()) {
247
+ orderedDrivers.push_back (originalDrivers[spec.globalIndex ]);
248
+ usedGlobalIndices.insert (spec.globalIndex );
249
+ }
250
+ break ;
251
+
252
+ case DriverOrderSpec::BY_TYPE:
253
+ // Add all drivers of this type that haven't been used
254
+ {
255
+ std::vector<uint32_t >* typeIndices = nullptr ;
256
+ switch (spec.driverType ) {
257
+ case ZEL_DRIVER_TYPE_DISCRETE_GPU:
258
+ typeIndices = &discreteGPUIndices;
259
+ break ;
260
+ case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
261
+ typeIndices = &integratedGPUIndices;
262
+ break ;
263
+ case ZEL_DRIVER_TYPE_GPU:
264
+ typeIndices = &gpuIndices;
265
+ break ;
266
+ case ZEL_DRIVER_TYPE_NPU:
267
+ case ZEL_DRIVER_TYPE_OTHER:
268
+ typeIndices = &npuIndices;
269
+ break ;
270
+ default :
271
+ break ;
272
+ }
273
+
274
+ if (typeIndices) {
275
+ for (uint32_t globalIdx : *typeIndices) {
276
+ if (usedGlobalIndices.find (globalIdx) == usedGlobalIndices.end ()) {
277
+ orderedDrivers.push_back (originalDrivers[globalIdx]);
278
+ usedGlobalIndices.insert (globalIdx);
279
+ }
280
+ }
281
+ }
282
+ }
283
+ break ;
284
+
285
+ case DriverOrderSpec::BY_TYPE_AND_INDEX:
286
+ {
287
+ std::vector<uint32_t >* typeIndices = nullptr ;
288
+ switch (spec.driverType ) {
289
+ case ZEL_DRIVER_TYPE_DISCRETE_GPU:
290
+ typeIndices = &discreteGPUIndices;
291
+ break ;
292
+ case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
293
+ typeIndices = &integratedGPUIndices;
294
+ break ;
295
+ case ZEL_DRIVER_TYPE_GPU:
296
+ typeIndices = &gpuIndices;
297
+ break ;
298
+ case ZEL_DRIVER_TYPE_NPU:
299
+ case ZEL_DRIVER_TYPE_OTHER:
300
+ typeIndices = &npuIndices;
301
+ break ;
302
+ default :
303
+ break ;
304
+ }
305
+
306
+ if (typeIndices && spec.typeIndex < typeIndices->size ()) {
307
+ auto typeIndexPair = std::make_pair (spec.driverType , spec.typeIndex );
308
+ if (usedTypeIndices.find (typeIndexPair) == usedTypeIndices.end ()) {
309
+ uint32_t globalIdx = (*typeIndices)[spec.typeIndex ];
310
+ if (usedGlobalIndices.find (globalIdx) == usedGlobalIndices.end ()) {
311
+ orderedDrivers.push_back (originalDrivers[globalIdx]);
312
+ usedGlobalIndices.insert (globalIdx);
313
+ usedTypeIndices.insert (typeIndexPair);
314
+ }
315
+ }
316
+ }
317
+ }
318
+ break ;
319
+ }
320
+ }
321
+
322
+ // Add remaining drivers in their original order
323
+ for (uint32_t i = 0 ; i < originalDrivers.size (); ++i) {
324
+ if (usedGlobalIndices.find (i) == usedGlobalIndices.end ()) {
325
+ orderedDrivers.push_back (originalDrivers[i]);
326
+ }
327
+ }
328
+
329
+ // Replace the original driver vector with the ordered one
330
+ *drivers = orderedDrivers;
331
+
332
+ if (debugTraceEnabled) {
333
+ std::string message = " driverOrdering: Drivers after ZEL_DRIVERS_ORDER:" ;
334
+ for (uint32_t i = 0 ; i < drivers->size (); ++i) {
335
+ message += " \n [" + std::to_string (i) + " ] Driver Type: " + std::to_string ((*drivers)[i].driverType ) + " Driver Name: " + (*drivers)[i].name ;
336
+ }
337
+ debug_trace_message (message, " " );
338
+ }
339
+ }
340
+
75
341
bool context_t::driverSorting (driver_vector_t *drivers, ze_init_driver_type_desc_t * desc, bool sysmanOnly) {
76
342
ze_init_driver_type_desc_t permissiveDesc = {};
77
343
permissiveDesc.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
@@ -246,6 +512,10 @@ namespace loader
246
512
}
247
513
debug_trace_message (message, " " );
248
514
}
515
+
516
+ // Apply driver ordering based on ZEL_DRIVERS_ORDER environment variable
517
+ driverOrdering (drivers);
518
+
249
519
return true ;
250
520
}
251
521
@@ -577,7 +847,7 @@ namespace loader
577
847
GET_FUNCTION_PTR (validationLayer, " zelLoaderGetVersion" ));
578
848
zel_component_version_t compVersion;
579
849
if (getVersion && ZE_RESULT_SUCCESS == getVersion (&compVersion))
580
- {
850
+ {
581
851
compVersions.push_back (compVersion);
582
852
}
583
853
} else if (debugTraceEnabled) {
@@ -602,7 +872,7 @@ namespace loader
602
872
GET_FUNCTION_PTR (tracingLayer, " zelLoaderGetVersion" ));
603
873
zel_component_version_t compVersion;
604
874
if (getVersion && ZE_RESULT_SUCCESS == getVersion (&compVersion))
605
- {
875
+ {
606
876
compVersions.push_back (compVersion);
607
877
}
608
878
} else if (debugTraceEnabled) {
0 commit comments