Skip to content

Commit 09d5064

Browse files
committed
[UR][L0] Set pointer kernel arguments only for queue's associated device
Ensure that pointer kernel arguments are set only for the device associated with the queue being used for kernel launch. Previously, arguments were set for all devices in the kernel's device map, which was unnecessary and potentially incorrect when launching on a specific device.
1 parent bdd0fd1 commit 09d5064

File tree

6 files changed

+145
-38
lines changed

6 files changed

+145
-38
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
// UNSUPPORTED: level_zero_v2_adapter
5+
// UNSUPPORTED-TRACKER: CMPLRLLVM-67039
6+
7+
// Test that usm device pointer can be used in a kernel compiled for a context
8+
// with multiple devices.
9+
10+
#include <iostream>
11+
#include <sycl/detail/core.hpp>
12+
#include <sycl/kernel_bundle.hpp>
13+
#include <sycl/platform.hpp>
14+
#include <sycl/usm.hpp>
15+
#include <vector>
16+
17+
using namespace sycl;
18+
19+
class AddIdxKernel;
20+
21+
int main() {
22+
sycl::platform plt;
23+
std::vector<sycl::device> devices = plt.get_devices();
24+
if (devices.size() < 2) {
25+
std::cout << "Need at least 2 GPU devices for this test.\n";
26+
return 0;
27+
}
28+
29+
std::vector<sycl::device> ctx_devices{devices[0], devices[1]};
30+
sycl::context ctx(ctx_devices);
31+
32+
constexpr size_t N = 16;
33+
std::vector<std::vector<int>> results(ctx_devices.size(),
34+
std::vector<int>(N, 0));
35+
36+
// Create a kernel bundle compiled for both devices in the context
37+
auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);
38+
39+
// For each device, create a queue and run a kernel using device USM
40+
for (size_t i = 0; i < ctx_devices.size(); ++i) {
41+
sycl::queue q(ctx, ctx_devices[i]);
42+
int *data = sycl::malloc_device<int>(N, q);
43+
q.fill(data, 1, N).wait();
44+
q.submit([&](sycl::handler &h) {
45+
h.use_kernel_bundle(kb);
46+
h.parallel_for<AddIdxKernel>(
47+
sycl::range<1>(N), [=](sycl::id<1> idx) { data[idx] += idx[0]; });
48+
}).wait();
49+
q.memcpy(results[i].data(), data, N * sizeof(int)).wait();
50+
sycl::free(data, q);
51+
}
52+
53+
for (size_t i = 0; i < ctx_devices.size(); ++i) {
54+
std::cout << "Device " << i << " results: ";
55+
for (size_t j = 0; j < N; ++j) {
56+
if (results[i][j] != 1 + static_cast<int>(j)) {
57+
return -1;
58+
}
59+
std::cout << results[i][j] << " ";
60+
}
61+
}
62+
return 0;
63+
}

unified-runtime/source/adapters/level_zero/command_buffer.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,12 +1004,16 @@ ur_result_t setKernelPendingArguments(
10041004
ze_kernel_handle_t ZeKernel) {
10051005
// If there are any pending arguments set them now.
10061006
for (auto &Arg : PendingArguments) {
1007-
// The ArgValue may be a NULL pointer in which case a NULL value is used for
1008-
// the kernel argument declared as a pointer to global or constant memory.
10091007
char **ZeHandlePtr = nullptr;
1010-
if (Arg.Value) {
1011-
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
1012-
nullptr, 0u));
1008+
if (auto MemObjPtr = std::get_if<ur_mem_handle_t>(&Arg.Value)) {
1009+
ur_mem_handle_t MemObj = *MemObjPtr;
1010+
if (MemObj) {
1011+
UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
1012+
nullptr, 0u));
1013+
}
1014+
} else {
1015+
auto Ptr = const_cast<void **>(&std::get<const void *>(Arg.Value));
1016+
ZeHandlePtr = reinterpret_cast<char **>(Ptr);
10131017
}
10141018
ZE2UR_CALL(zeKernelSetArgumentValue,
10151019
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));

unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,25 @@ ur_result_t calculateKernelWorkDimensions(
156156

157157
return UR_RESULT_SUCCESS;
158158
}
159+
160+
ur_result_t setArgValueOnZeKernel(ze_kernel_handle_t hZeKernel,
161+
uint32_t argIndex, size_t argSize,
162+
const void *pArgValue) {
163+
// OpenCL: "the arg_value pointer can be NULL or point to a NULL value
164+
// in which case a NULL value will be used as the value for the argument
165+
// declared as a pointer to global or constant memory in the kernel"
166+
//
167+
// We don't know the type of the argument but it seems that the only time
168+
// SYCL RT would send a pointer to NULL in 'arg_value' is when the argument
169+
// is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL.
170+
if (argSize == sizeof(void *) && pArgValue &&
171+
*(void **)(const_cast<void *>(pArgValue)) == nullptr) {
172+
pArgValue = nullptr;
173+
}
174+
175+
ze_result_t ZeResult = ZE_CALL_NOCHECK(
176+
zeKernelSetArgumentValue, (hZeKernel, argIndex, argSize, pArgValue));
177+
if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT)
178+
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
179+
return ze2urResult(ZeResult);
180+
}

unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,15 @@ inline void postSubmit(ze_kernel_handle_t hZeKernel,
7171
zeKernelSetGlobalOffsetExp(hZeKernel, 0, 0, 0);
7272
}
7373
}
74+
75+
/**
76+
* Helper to set kernel argument for ze_kernel_handle_t.
77+
* @param[in] hZeKernel The handle to the Level-Zero kernel.
78+
* @param[in] argIndex The index of the argument to set.
79+
* @param[in] argSize The size of the argument to set.
80+
* @param[in] pArgValue The pointer to the argument value.
81+
* @return UR_RESULT_SUCCESS or an error code on failure
82+
*/
83+
ur_result_t setArgValueOnZeKernel(ze_kernel_handle_t hZeKernel,
84+
uint32_t argIndex, size_t argSize,
85+
const void *pArgValue);

unified-runtime/source/adapters/level_zero/kernel.cpp

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,22 @@ ur_result_t urEnqueueKernelLaunch(
125125

126126
// If there are any pending arguments set them now.
127127
for (auto &Arg : Kernel->PendingArguments) {
128-
// The ArgValue may be a NULL pointer in which case a NULL value is used for
129-
// the kernel argument declared as a pointer to global or constant memory.
128+
// The Arg.Value can be either a ur_mem_handle_t or a raw pointer
129+
// (const void*). Resolve per-device: for mem handles obtain the device
130+
// specific handle, otherwise pass the raw pointer value.
130131
char **ZeHandlePtr = nullptr;
131-
if (Arg.Value) {
132-
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
133-
Queue->Device, EventWaitList,
134-
NumEventsInWaitList));
132+
if (auto MemObjPtr = std::get_if<ur_mem_handle_t>(&Arg.Value)) {
133+
ur_mem_handle_t MemObj = *MemObjPtr;
134+
if (MemObj) {
135+
UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
136+
Queue->Device, EventWaitList,
137+
NumEventsInWaitList));
138+
}
139+
} else {
140+
auto Ptr = const_cast<void **>(&std::get<const void *>(Arg.Value));
141+
ZeHandlePtr = reinterpret_cast<char **>(Ptr);
135142
}
136-
ZE2UR_CALL(zeKernelSetArgumentValue,
137-
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
143+
UR_CALL(setArgValueOnZeKernel(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
138144
}
139145
Kernel->PendingArguments.clear();
140146

@@ -422,41 +428,21 @@ ur_result_t urKernelSetArgValue(
422428

423429
UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
424430

425-
// OpenCL: "the arg_value pointer can be NULL or point to a NULL value
426-
// in which case a NULL value will be used as the value for the argument
427-
// declared as a pointer to global or constant memory in the kernel"
428-
//
429-
// We don't know the type of the argument but it seems that the only time
430-
// SYCL RT would send a pointer to NULL in 'arg_value' is when the argument
431-
// is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL.
432-
if (ArgSize == sizeof(void *) && PArgValue &&
433-
*(void **)(const_cast<void *>(PArgValue)) == nullptr) {
434-
PArgValue = nullptr;
435-
}
436-
437431
if (ArgIndex > Kernel->ZeKernelProperties->numKernelArgs - 1) {
438432
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
439433
}
440434

441435
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
442-
ze_result_t ZeResult = ZE_RESULT_SUCCESS;
443436
if (Kernel->ZeKernelMap.empty()) {
444437
auto ZeKernel = Kernel->ZeKernel;
445-
ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue,
446-
(ZeKernel, ArgIndex, ArgSize, PArgValue));
438+
UR_CALL(setArgValueOnZeKernel(ZeKernel, ArgIndex, ArgSize, PArgValue))
447439
} else {
448440
for (auto It : Kernel->ZeKernelMap) {
449441
auto ZeKernel = It.second;
450-
ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue,
451-
(ZeKernel, ArgIndex, ArgSize, PArgValue));
442+
UR_CALL(setArgValueOnZeKernel(ZeKernel, ArgIndex, ArgSize, PArgValue))
452443
}
453444
}
454-
455-
if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
456-
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
457-
}
458-
459-
return ze2urResult(ZeResult);
445+
return UR_RESULT_SUCCESS;
460446
}
461447

462448
ur_result_t urKernelSetArgLocal(
@@ -732,6 +718,23 @@ ur_result_t urKernelSetArgPointer(
732718
/// [in][optional] SVM pointer to memory location holding the argument
733719
/// value. If null then argument value is considered null.
734720
const void *ArgValue) {
721+
UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
722+
{
723+
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
724+
// In multi-device context instead of setting pointer arguments immediately
725+
// across all device kernels, store them as pending so they can be resolved
726+
// per-device at enqueue time. This ensures the correct handle is used for
727+
// the device of the queue.
728+
if (Kernel->Program->Context->getDevices().size() > 1) {
729+
if (ArgIndex > Kernel->ZeKernelProperties->numKernelArgs - 1) {
730+
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
731+
}
732+
Kernel->PendingArguments.push_back({ArgIndex, sizeof(const void *),
733+
ArgValue, ur_mem_handle_t_::unknown});
734+
735+
return UR_RESULT_SUCCESS;
736+
}
737+
}
735738

736739
// KernelSetArgValue is expecting a pointer to the argument
737740
UR_CALL(ur::level_zero::urKernelSetArgValue(

unified-runtime/source/adapters/level_zero/kernel.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include <unordered_set>
13+
#include <variant>
1314

1415
#include "common.hpp"
1516
#include "common/ur_ref_count.hpp"
@@ -97,8 +98,10 @@ struct ur_kernel_handle_t_ : ur_object {
9798
struct ArgumentInfo {
9899
uint32_t Index;
99100
size_t Size;
100-
// const ur_mem_handle_t_ *Value;
101-
ur_mem_handle_t_ *Value;
101+
// Value may be either a memory object or a raw pointer value (for pointer
102+
// arguments). Resolve at enqueue time per-device to ensure correct handle
103+
// is used for that device.
104+
std::variant<ur_mem_handle_t, const void *> Value;
102105
ur_mem_handle_t_::access_mode_t AccessMode{ur_mem_handle_t_::unknown};
103106
};
104107
// Arguments that still need to be set (with zeKernelSetArgumentValue)

0 commit comments

Comments
 (0)