|
5 | 5 | * SPDX-License-Identifier: MIT
|
6 | 6 | *
|
7 | 7 | */
|
8 |
| -#include "ze_loader_internal.h" |
| 8 | +#include "ze_loader_utils.h" |
9 | 9 |
|
10 | 10 | #include "driver_discovery.h"
|
11 | 11 | #include <iostream>
|
12 |
| -#include <string> |
13 |
| -#include <vector> |
14 |
| -#include <map> |
15 |
| -#include <set> |
16 |
| -#include <sstream> |
17 |
| -#include <cstdlib> |
18 |
| -#include <algorithm> |
19 | 12 |
|
20 | 13 | #ifdef __linux__
|
21 | 14 | #include <unistd.h>
|
@@ -79,92 +72,6 @@ namespace loader
|
79 | 72 | return a.driverType < b.driverType;
|
80 | 73 | }
|
81 | 74 |
|
82 |
| - // Helper function to map driver type string to enum |
83 |
| - zel_driver_type_t stringToDriverType(const std::string& typeStr) { |
84 |
| - if (typeStr == "DISCRETE_GPU_ONLY") { |
85 |
| - return ZEL_DRIVER_TYPE_DISCRETE_GPU; |
86 |
| - } else if (typeStr == "GPU") { |
87 |
| - return ZEL_DRIVER_TYPE_GPU; |
88 |
| - } else if (typeStr == "INTEGRATED_GPU_ONLY") { |
89 |
| - return ZEL_DRIVER_TYPE_INTEGRATED_GPU; |
90 |
| - } else if (typeStr == "NPU") { |
91 |
| - return ZEL_DRIVER_TYPE_NPU; |
92 |
| - } |
93 |
| - return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid |
94 |
| - } |
95 |
| - |
96 |
| - // Helper function to trim whitespace |
97 |
| - std::string trim(const std::string& str) { |
98 |
| - const std::string whitespace = " \t\n\r\f\v"; |
99 |
| - size_t start = str.find_first_not_of(whitespace); |
100 |
| - if (start == std::string::npos) return ""; |
101 |
| - size_t end = str.find_last_not_of(whitespace); |
102 |
| - return str.substr(start, end - start + 1); |
103 |
| - } |
104 |
| - |
105 |
| - // Structure to hold parsed ordering instructions |
106 |
| - struct DriverOrderSpec { |
107 |
| - enum Type { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX } type; |
108 |
| - uint32_t globalIndex = 0; |
109 |
| - zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32; |
110 |
| - uint32_t typeIndex = 0; |
111 |
| - }; |
112 |
| - |
113 |
| - // Parse ZEL_DRIVERS_ORDER environment variable |
114 |
| - std::vector<DriverOrderSpec> parseDriverOrder(const std::string& orderStr) { |
115 |
| - std::vector<DriverOrderSpec> specs; |
116 |
| - |
117 |
| - // Split by comma |
118 |
| - std::vector<std::string> tokens; |
119 |
| - std::stringstream ss(orderStr); |
120 |
| - std::string token; |
121 |
| - |
122 |
| - while (std::getline(ss, token, ',')) { |
123 |
| - token = trim(token); |
124 |
| - if (token.empty()) continue; |
125 |
| - |
126 |
| - DriverOrderSpec spec; |
127 |
| - |
128 |
| - // Check if it contains a colon (type:index format) |
129 |
| - size_t colonPos = token.find(':'); |
130 |
| - if (colonPos != std::string::npos) { |
131 |
| - // Format: <driver_type>:<driver_index> |
132 |
| - std::string typeStr = trim(token.substr(0, colonPos)); |
133 |
| - std::string indexStr = trim(token.substr(colonPos + 1)); |
134 |
| - |
135 |
| - spec.driverType = stringToDriverType(typeStr); |
136 |
| - if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) { |
137 |
| - continue; // Invalid driver type, skip |
138 |
| - } |
139 |
| - |
140 |
| - try { |
141 |
| - spec.typeIndex = std::stoul(indexStr); |
142 |
| - spec.type = DriverOrderSpec::BY_TYPE_AND_INDEX; |
143 |
| - specs.push_back(spec); |
144 |
| - } catch (const std::exception&) { |
145 |
| - // Invalid index, skip |
146 |
| - continue; |
147 |
| - } |
148 |
| - } else { |
149 |
| - // Check if it's a pure number (global index) or driver type |
150 |
| - try { |
151 |
| - spec.globalIndex = std::stoul(token); |
152 |
| - spec.type = DriverOrderSpec::BY_GLOBAL_INDEX; |
153 |
| - specs.push_back(spec); |
154 |
| - } catch (const std::exception&) { |
155 |
| - // Not a number, try as driver type |
156 |
| - spec.driverType = stringToDriverType(token); |
157 |
| - if (spec.driverType != ZEL_DRIVER_TYPE_FORCE_UINT32) { |
158 |
| - spec.type = DriverOrderSpec::BY_TYPE; |
159 |
| - specs.push_back(spec); |
160 |
| - } |
161 |
| - } |
162 |
| - } |
163 |
| - } |
164 |
| - |
165 |
| - return specs; |
166 |
| - } |
167 |
| - |
168 | 75 | void context_t::driverOrdering(driver_vector_t *drivers) {
|
169 | 76 | std::string orderStr = getenv_string("ZEL_DRIVERS_ORDER");
|
170 | 77 | if (orderStr.empty()) {
|
@@ -241,15 +148,15 @@ namespace loader
|
241 | 148 | // Apply ordering specifications
|
242 | 149 | for (const auto& spec : specs) {
|
243 | 150 | switch (spec.type) {
|
244 |
| - case DriverOrderSpec::BY_GLOBAL_INDEX: |
| 151 | + case DriverOrderSpecType::BY_GLOBAL_INDEX: |
245 | 152 | if (spec.globalIndex < originalDrivers.size() &&
|
246 | 153 | usedGlobalIndices.find(spec.globalIndex) == usedGlobalIndices.end()) {
|
247 | 154 | orderedDrivers.push_back(originalDrivers[spec.globalIndex]);
|
248 | 155 | usedGlobalIndices.insert(spec.globalIndex);
|
249 | 156 | }
|
250 | 157 | break;
|
251 | 158 |
|
252 |
| - case DriverOrderSpec::BY_TYPE: |
| 159 | + case DriverOrderSpecType::BY_TYPE: |
253 | 160 | // Add all drivers of this type that haven't been used
|
254 | 161 | {
|
255 | 162 | std::vector<uint32_t>* typeIndices = nullptr;
|
@@ -282,7 +189,7 @@ namespace loader
|
282 | 189 | }
|
283 | 190 | break;
|
284 | 191 |
|
285 |
| - case DriverOrderSpec::BY_TYPE_AND_INDEX: |
| 192 | + case DriverOrderSpecType::BY_TYPE_AND_INDEX: |
286 | 193 | {
|
287 | 194 | std::vector<uint32_t>* typeIndices = nullptr;
|
288 | 195 | switch (spec.driverType) {
|
|
0 commit comments