@@ -305,6 +305,14 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
305
305
}
306
306
pending_allocations.clear ();
307
307
308
+ // Apply any pending raw pointer arguments (USM pointers) for this device.
309
+ for (auto &pending : pending_pointer_args) {
310
+ void *Ptr = const_cast <void *>(pending.ptrArgValue );
311
+ ZE2UR_CALL (zeKernelSetArgumentValue,
312
+ (hZeKernel, pending.argIndex , sizeof (void *), &Ptr));
313
+ }
314
+ pending_pointer_args.clear ();
315
+
308
316
return UR_RESULT_SUCCESS;
309
317
}
310
318
@@ -319,6 +327,18 @@ ur_result_t ur_kernel_handle_t_::addPendingMemoryAllocation(
319
327
return UR_RESULT_SUCCESS;
320
328
}
321
329
330
+ ur_result_t
331
+ ur_kernel_handle_t_::addPendingPointerArgument (uint32_t argIndex,
332
+ const void *pArgValue) {
333
+ if (argIndex > zeCommonProperties->numKernelArgs - 1 ) {
334
+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
335
+ }
336
+
337
+ pending_pointer_args.push_back ({argIndex, pArgValue});
338
+
339
+ return UR_RESULT_SUCCESS;
340
+ }
341
+
322
342
std::vector<char > ur_kernel_handle_t_::getSourceAttributes () const {
323
343
uint32_t size;
324
344
ZE2UR_CALL_THROWS (zeKernelGetSourceAttributes,
@@ -405,14 +425,17 @@ ur_result_t urKernelSetArgPointer(
405
425
ur_kernel_handle_t hKernel, // /< [in] handle of the kernel object
406
426
uint32_t argIndex, // /< [in] argument index in range [0, num args - 1]
407
427
const ur_kernel_arg_pointer_properties_t
408
- *pProperties , // /< [in][optional] argument properties
428
+ *, // /< [in][optional] argument properties
409
429
const void
410
430
*pArgValue // /< [in] argument value represented as matching arg type.
411
431
) try {
412
432
TRACK_SCOPE_LATENCY (" urKernelSetArgPointer" );
413
433
414
434
std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
415
- return hKernel->setArgPointer (argIndex, pProperties, pArgValue);
435
+ // Store the raw pointer value and defer setting the argument until
436
+ // we know the device where kernel is being submitted.
437
+ hKernel->addPendingPointerArgument (argIndex, pArgValue);
438
+ return UR_RESULT_SUCCESS;
416
439
} catch (...) {
417
440
return exceptionToResult (std::current_exception ());
418
441
}
0 commit comments