Skip to content

Commit d9a64e1

Browse files
committed
Move helper functions to new header and update tests
Signed-off-by: Neil R. Spruit <[email protected]>
1 parent 18d513c commit d9a64e1

File tree

4 files changed

+221
-300
lines changed

4 files changed

+221
-300
lines changed

source/loader/ze_loader.cpp

Lines changed: 4 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,10 @@
55
* SPDX-License-Identifier: MIT
66
*
77
*/
8-
#include "ze_loader_internal.h"
8+
#include "ze_loader_utils.h"
99

1010
#include "driver_discovery.h"
1111
#include <iostream>
12-
#include <string>
13-
#include <vector>
14-
#include <map>
15-
#include <set>
16-
#include <sstream>
17-
#include <cstdlib>
18-
#include <algorithm>
1912

2013
#ifdef __linux__
2114
#include <unistd.h>
@@ -79,92 +72,6 @@ namespace loader
7972
return a.driverType < b.driverType;
8073
}
8174

82-
// Helper function to map driver type string to enum
83-
zel_driver_type_t stringToDriverType(const std::string& typeStr) {
84-
if (typeStr == "DISCRETE_GPU_ONLY") {
85-
return ZEL_DRIVER_TYPE_DISCRETE_GPU;
86-
} else if (typeStr == "GPU") {
87-
return ZEL_DRIVER_TYPE_GPU;
88-
} else if (typeStr == "INTEGRATED_GPU_ONLY") {
89-
return ZEL_DRIVER_TYPE_INTEGRATED_GPU;
90-
} else if (typeStr == "NPU") {
91-
return ZEL_DRIVER_TYPE_NPU;
92-
}
93-
return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid
94-
}
95-
96-
// Helper function to trim whitespace
97-
std::string trim(const std::string& str) {
98-
const std::string whitespace = " \t\n\r\f\v";
99-
size_t start = str.find_first_not_of(whitespace);
100-
if (start == std::string::npos) return "";
101-
size_t end = str.find_last_not_of(whitespace);
102-
return str.substr(start, end - start + 1);
103-
}
104-
105-
// Structure to hold parsed ordering instructions
106-
struct DriverOrderSpec {
107-
enum Type { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX } type;
108-
uint32_t globalIndex = 0;
109-
zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32;
110-
uint32_t typeIndex = 0;
111-
};
112-
113-
// Parse ZEL_DRIVERS_ORDER environment variable
114-
std::vector<DriverOrderSpec> parseDriverOrder(const std::string& orderStr) {
115-
std::vector<DriverOrderSpec> specs;
116-
117-
// Split by comma
118-
std::vector<std::string> tokens;
119-
std::stringstream ss(orderStr);
120-
std::string token;
121-
122-
while (std::getline(ss, token, ',')) {
123-
token = trim(token);
124-
if (token.empty()) continue;
125-
126-
DriverOrderSpec spec;
127-
128-
// Check if it contains a colon (type:index format)
129-
size_t colonPos = token.find(':');
130-
if (colonPos != std::string::npos) {
131-
// Format: <driver_type>:<driver_index>
132-
std::string typeStr = trim(token.substr(0, colonPos));
133-
std::string indexStr = trim(token.substr(colonPos + 1));
134-
135-
spec.driverType = stringToDriverType(typeStr);
136-
if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) {
137-
continue; // Invalid driver type, skip
138-
}
139-
140-
try {
141-
spec.typeIndex = std::stoul(indexStr);
142-
spec.type = DriverOrderSpec::BY_TYPE_AND_INDEX;
143-
specs.push_back(spec);
144-
} catch (const std::exception&) {
145-
// Invalid index, skip
146-
continue;
147-
}
148-
} else {
149-
// Check if it's a pure number (global index) or driver type
150-
try {
151-
spec.globalIndex = std::stoul(token);
152-
spec.type = DriverOrderSpec::BY_GLOBAL_INDEX;
153-
specs.push_back(spec);
154-
} catch (const std::exception&) {
155-
// Not a number, try as driver type
156-
spec.driverType = stringToDriverType(token);
157-
if (spec.driverType != ZEL_DRIVER_TYPE_FORCE_UINT32) {
158-
spec.type = DriverOrderSpec::BY_TYPE;
159-
specs.push_back(spec);
160-
}
161-
}
162-
}
163-
}
164-
165-
return specs;
166-
}
167-
16875
void context_t::driverOrdering(driver_vector_t *drivers) {
16976
std::string orderStr = getenv_string("ZEL_DRIVERS_ORDER");
17077
if (orderStr.empty()) {
@@ -241,15 +148,15 @@ namespace loader
241148
// Apply ordering specifications
242149
for (const auto& spec : specs) {
243150
switch (spec.type) {
244-
case DriverOrderSpec::BY_GLOBAL_INDEX:
151+
case DriverOrderSpecType::BY_GLOBAL_INDEX:
245152
if (spec.globalIndex < originalDrivers.size() &&
246153
usedGlobalIndices.find(spec.globalIndex) == usedGlobalIndices.end()) {
247154
orderedDrivers.push_back(originalDrivers[spec.globalIndex]);
248155
usedGlobalIndices.insert(spec.globalIndex);
249156
}
250157
break;
251158

252-
case DriverOrderSpec::BY_TYPE:
159+
case DriverOrderSpecType::BY_TYPE:
253160
// Add all drivers of this type that haven't been used
254161
{
255162
std::vector<uint32_t>* typeIndices = nullptr;
@@ -282,7 +189,7 @@ namespace loader
282189
}
283190
break;
284191

285-
case DriverOrderSpec::BY_TYPE_AND_INDEX:
192+
case DriverOrderSpecType::BY_TYPE_AND_INDEX:
286193
{
287194
std::vector<uint32_t>* typeIndices = nullptr;
288195
switch (spec.driverType) {

source/loader/ze_loader_utils.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
*
3+
* Copyright (C) 2025 Intel Corporation
4+
*
5+
* SPDX-License-Identifier: MIT
6+
*
7+
*/
8+
#pragma once
9+
#include "ze_loader_internal.h"
10+
#include <string>
11+
#include <vector>
12+
#include <map>
13+
#include <set>
14+
#include <sstream>
15+
#include <cstdlib>
16+
#include <algorithm>
17+
18+
19+
namespace loader
20+
{
21+
// Helper function to map driver type string to enum
22+
inline zel_driver_type_t stringToDriverType(const std::string& typeStr) {
23+
if (typeStr == "DISCRETE_GPU_ONLY") {
24+
return ZEL_DRIVER_TYPE_DISCRETE_GPU;
25+
} else if (typeStr == "GPU") {
26+
return ZEL_DRIVER_TYPE_GPU;
27+
} else if (typeStr == "INTEGRATED_GPU_ONLY") {
28+
return ZEL_DRIVER_TYPE_INTEGRATED_GPU;
29+
} else if (typeStr == "NPU") {
30+
return ZEL_DRIVER_TYPE_NPU;
31+
}
32+
return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid
33+
}
34+
35+
// Helper function to trim whitespace
36+
inline std::string trim(const std::string& str) {
37+
const std::string whitespace = " \t\n\r\f\v";
38+
size_t start = str.find_first_not_of(whitespace);
39+
if (start == std::string::npos) return "";
40+
size_t end = str.find_last_not_of(whitespace);
41+
return str.substr(start, end - start + 1);
42+
}
43+
44+
enum DriverOrderSpecType { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX };
45+
46+
// Structure to hold parsed ordering instructions
47+
struct DriverOrderSpec {
48+
DriverOrderSpecType type;
49+
uint32_t globalIndex = 0;
50+
zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32;
51+
uint32_t typeIndex = 0;
52+
};
53+
54+
// Parse ZEL_DRIVERS_ORDER environment variable
55+
inline std::vector<DriverOrderSpec> parseDriverOrder(const std::string& orderStr) {
56+
std::vector<DriverOrderSpec> specs;
57+
58+
// Split by comma
59+
std::vector<std::string> tokens;
60+
std::stringstream ss(orderStr);
61+
std::string token;
62+
63+
while (std::getline(ss, token, ',')) {
64+
token = trim(token);
65+
if (token.empty()) continue;
66+
67+
DriverOrderSpec spec;
68+
69+
// Check if it contains a colon (type:index format)
70+
size_t colonPos = token.find(':');
71+
if (colonPos != std::string::npos) {
72+
// Format: <driver_type>:<driver_index>
73+
std::string typeStr = trim(token.substr(0, colonPos));
74+
std::string indexStr = trim(token.substr(colonPos + 1));
75+
76+
spec.driverType = stringToDriverType(typeStr);
77+
if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) {
78+
continue; // Invalid driver type, skip
79+
}
80+
81+
try {
82+
spec.typeIndex = std::stoul(indexStr);
83+
spec.type = DriverOrderSpecType::BY_TYPE_AND_INDEX;
84+
specs.push_back(spec);
85+
} catch (const std::exception&) {
86+
// Invalid index, skip
87+
continue;
88+
}
89+
} else {
90+
// Check if it's a pure number (global index) or driver type
91+
try {
92+
spec.globalIndex = std::stoul(token);
93+
spec.type = DriverOrderSpecType::BY_GLOBAL_INDEX;
94+
specs.push_back(spec);
95+
} catch (const std::exception&) {
96+
// Not a number, try as driver type
97+
spec.driverType = stringToDriverType(token);
98+
if (spec.driverType != ZEL_DRIVER_TYPE_FORCE_UINT32) {
99+
spec.type = DriverOrderSpecType::BY_TYPE;
100+
specs.push_back(spec);
101+
}
102+
}
103+
}
104+
}
105+
106+
return specs;
107+
}
108+
} // namespace loader

0 commit comments

Comments
 (0)