Skip to content

Commit f961f83

Browse files
committed
[UR][L0 v2] Set pointer kernel arguments only for queue's associated device
1 parent 2af08ff commit f961f83

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

unified-runtime/source/adapters/level_zero/v2/kernel.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,14 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
305305
}
306306
pending_allocations.clear();
307307

308+
// Apply any pending raw pointer arguments (USM pointers) for this device.
309+
for (auto &pending : pending_pointer_args) {
310+
void *Ptr = const_cast<void *>(pending.ptrArgValue);
311+
ZE2UR_CALL(zeKernelSetArgumentValue,
312+
(hZeKernel, pending.argIndex, sizeof(void *), &Ptr));
313+
}
314+
pending_pointer_args.clear();
315+
308316
return UR_RESULT_SUCCESS;
309317
}
310318

@@ -319,6 +327,18 @@ ur_result_t ur_kernel_handle_t_::addPendingMemoryAllocation(
319327
return UR_RESULT_SUCCESS;
320328
}
321329

330+
// Records a raw pointer (USM) kernel argument so that it can be applied
// later, at submission time, when the target device is known.
//
// argIndex  - index of the kernel argument; must be less than
//             zeCommonProperties->numKernelArgs.
// pArgValue - the pointer value to pass for that argument. Only the pointer
//             itself is stored; it is not dereferenced here.
//
// Returns UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX (and records nothing)
// when argIndex is out of range, UR_RESULT_SUCCESS otherwise.
ur_result_t
ur_kernel_handle_t_::addPendingPointerArgument(uint32_t argIndex,
                                               const void *pArgValue) {
  // Use ">=" instead of "> numKernelArgs - 1": argIndex is unsigned, so the
  // comparison is done in unsigned arithmetic and "numKernelArgs - 1" wraps
  // to a huge value for a kernel without arguments, which would let every
  // index through.
  if (argIndex >= zeCommonProperties->numKernelArgs) {
    return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
  }

  // Pending entries are applied in order at submission, so only the most
  // recent value for a given index matters. Overwrite an existing entry
  // rather than letting repeated sets of the same argument grow the list.
  for (auto &pending : pending_pointer_args) {
    if (pending.argIndex == argIndex) {
      pending.ptrArgValue = pArgValue;
      return UR_RESULT_SUCCESS;
    }
  }

  pending_pointer_args.push_back({argIndex, pArgValue});

  return UR_RESULT_SUCCESS;
}
341+
322342
std::vector<char> ur_kernel_handle_t_::getSourceAttributes() const {
323343
uint32_t size;
324344
ZE2UR_CALL_THROWS(zeKernelGetSourceAttributes,
@@ -405,14 +425,17 @@ ur_result_t urKernelSetArgPointer(
405425
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
406426
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
407427
const ur_kernel_arg_pointer_properties_t
408-
*pProperties, ///< [in][optional] argument properties
428+
*, ///< [in][optional] argument properties
409429
const void
410430
*pArgValue ///< [in] argument value represented as matching arg type.
411431
) try {
412432
TRACK_SCOPE_LATENCY("urKernelSetArgPointer");
413433

414434
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
415-
return hKernel->setArgPointer(argIndex, pProperties, pArgValue);
435+
// Store the raw pointer value and defer setting the argument until
436+
// we know the device where kernel is being submitted.
437+
hKernel->addPendingPointerArgument(argIndex, pArgValue);
438+
return UR_RESULT_SUCCESS;
416439
} catch (...) {
417440
return exceptionToResult(std::current_exception());
418441
}

unified-runtime/source/adapters/level_zero/v2/kernel.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ struct ur_kernel_handle_t_ : ur_object {
8282
ur_result_t
8383
addPendingMemoryAllocation(pending_memory_allocation_t allocation);
8484

85+
// Add a pending pointer argument for which device is not yet known.
86+
ur_result_t addPendingPointerArgument(uint32_t argIndex,
87+
const void *pArgValue);
88+
8589
// Set all required values for the kernel before submission (including pending
8690
// memory allocations).
8791
ur_result_t prepareForSubmission(ur_context_handle_t hContext,
@@ -115,6 +119,14 @@ struct ur_kernel_handle_t_ : ur_object {
115119

116120
std::vector<pending_memory_allocation_t> pending_allocations;
117121

122+
// A raw pointer kernel argument that has been recorded but not yet applied
// to a device-specific kernel.
//
// The members carry default initializers so that a default-constructed
// instance never holds indeterminate values; the type remains an aggregate,
// so `{argIndex, ptrArgValue}` initialization keeps working.
struct pending_pointer_arg_t {
  // Index of the kernel argument the value belongs to.
  uint32_t argIndex = 0;
  // Pointer value to pass for that argument (stored as-is, not owned).
  const void *ptrArgValue = nullptr;
};
126+
127+
// Pointer arguments that need to be applied per-device at submission time.
128+
std::vector<pending_pointer_arg_t> pending_pointer_args;
129+
118130
void completeInitialization();
119131

120132
// pointer to any non-null kernel in deviceKernels

0 commit comments

Comments
 (0)