Skip to content

Commit c081695

Browse files
committed
[UR][Offload] Use new olMemInfo systems
We no longer track allocations in the adapter, and `urUSMGetMemAllocInfo` is now fully implemented.
1 parent ec4e5d1 commit c081695

File tree

3 files changed

+89
-39
lines changed

3 files changed

+89
-39
lines changed

unified-runtime/source/adapters/offload/context.hpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,14 @@ struct ur_context_handle_t_ : RefCounted {
2929
~ur_context_handle_t_() { urDeviceRelease(Device); }
3030

3131
ur_device_handle_t Device;
32-
std::unordered_map<void *, alloc_info_t> AllocTypeMap;
3332

34-
std::optional<alloc_info_t> getAllocType(const void *UsmPtr) {
35-
for (auto &pair : AllocTypeMap) {
36-
if (UsmPtr >= pair.first &&
37-
reinterpret_cast<uintptr_t>(UsmPtr) <
38-
reinterpret_cast<uintptr_t>(pair.first) + pair.second.Size) {
39-
return pair.second;
40-
}
33+
ol_result_t getAllocType(const void *UsmPtr, ol_alloc_type_t &Type) {
34+
auto Err = olGetMemInfo(UsmPtr, OL_MEM_INFO_TYPE, sizeof(Type), &Type);
35+
if (Err && Err->Code == OL_ERRC_NOT_FOUND) {
36+
// Treat unknown allocations as host
37+
Type = OL_ALLOC_TYPE_HOST;
38+
return OL_SUCCESS;
4139
}
42-
return std::nullopt;
40+
return Err;
4341
}
4442
};

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -440,17 +440,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
440440
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
441441
size_t size, uint32_t numEventsInWaitList,
442442
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
443-
auto GetDevice = [&](const void *Ptr) {
444-
auto Res = hQueue->UrContext->getAllocType(Ptr);
445-
if (!Res)
446-
return Adapter->HostDevice;
447-
return Res->Type == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice
448-
: hQueue->OffloadDevice;
449-
};
450-
451-
return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, GetDevice(pDst), pSrc,
452-
GetDevice(pSrc), size, blocking, numEventsInWaitList,
453-
phEventWaitList, phEvent);
443+
ol_alloc_type_t DstTy;
444+
OL_RETURN_ON_ERR(hQueue->UrContext->getAllocType(pDst, DstTy));
445+
ol_device_handle_t Dst =
446+
DstTy == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice : hQueue->OffloadDevice;
447+
448+
ol_alloc_type_t SrcTy;
449+
OL_RETURN_ON_ERR(hQueue->UrContext->getAllocType(pSrc, SrcTy));
450+
ol_device_handle_t Src =
451+
SrcTy == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice : hQueue->OffloadDevice;
452+
453+
return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, Dst, pSrc, Src, size,
454+
blocking, numEventsInWaitList, phEventWaitList, phEvent);
454455

455456
return UR_RESULT_SUCCESS;
456457
}

unified-runtime/source/adapters/offload/usm.cpp

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
2222
size_t size, void **ppMem) {
2323
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
2424
OL_ALLOC_TYPE_HOST, size, ppMem));
25-
26-
hContext->AllocTypeMap.insert_or_assign(
27-
*ppMem, alloc_info_t{OL_ALLOC_TYPE_HOST, size});
2825
return UR_RESULT_SUCCESS;
2926
}
3027

@@ -33,9 +30,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
3330
ur_usm_pool_handle_t, size_t size, void **ppMem) {
3431
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
3532
OL_ALLOC_TYPE_DEVICE, size, ppMem));
36-
37-
hContext->AllocTypeMap.insert_or_assign(
38-
*ppMem, alloc_info_t{OL_ALLOC_TYPE_DEVICE, size});
3933
return UR_RESULT_SUCCESS;
4034
}
4135

@@ -44,23 +38,80 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
4438
ur_usm_pool_handle_t, size_t size, void **ppMem) {
4539
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
4640
OL_ALLOC_TYPE_MANAGED, size, ppMem));
47-
48-
hContext->AllocTypeMap.insert_or_assign(
49-
*ppMem, alloc_info_t{OL_ALLOC_TYPE_MANAGED, size});
5041
return UR_RESULT_SUCCESS;
5142
}
5243

53-
UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
54-
void *pMem) {
55-
hContext->AllocTypeMap.erase(pMem);
44+
UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t, void *pMem) {
5645
return offloadResultToUR(olMemFree(pMem));
5746
}
5847

59-
UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo(
60-
[[maybe_unused]] ur_context_handle_t hContext,
61-
[[maybe_unused]] const void *pMem,
62-
[[maybe_unused]] ur_usm_alloc_info_t propName,
63-
[[maybe_unused]] size_t propSize, [[maybe_unused]] void *pPropValue,
64-
[[maybe_unused]] size_t *pPropSizeRet) {
65-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
48+
UR_APIEXPORT ur_result_t UR_APICALL
49+
urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
50+
ur_usm_alloc_info_t propName, size_t propSize,
51+
void *pPropValue, size_t *pPropSizeRet) {
52+
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
53+
54+
ol_mem_info_t olInfo;
55+
56+
switch (propName) {
57+
case UR_USM_ALLOC_INFO_TYPE:
58+
olInfo = OL_MEM_INFO_TYPE;
59+
break;
60+
case UR_USM_ALLOC_INFO_BASE_PTR:
61+
olInfo = OL_MEM_INFO_BASE;
62+
break;
63+
case UR_USM_ALLOC_INFO_SIZE:
64+
olInfo = OL_MEM_INFO_SIZE;
65+
break;
66+
case UR_USM_ALLOC_INFO_DEVICE:
67+
// Contexts can only contain one device
68+
return ReturnValue(hContext->Device);
69+
case UR_USM_ALLOC_INFO_POOL:
70+
default:
71+
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
72+
break;
73+
}
74+
75+
if (pPropSizeRet) {
76+
OL_RETURN_ON_ERR(olGetMemInfoSize(pMem, olInfo, pPropSizeRet));
77+
}
78+
79+
if (pPropValue) {
80+
auto Err = olGetMemInfo(pMem, olInfo, propSize, pPropValue);
81+
if (Err && Err->Code == OL_ERRC_NOT_FOUND) {
82+
// If the device didn't allocate this object, return default values
83+
switch (propName) {
84+
case UR_USM_ALLOC_INFO_TYPE:
85+
return ReturnValue(UR_USM_TYPE_UNKNOWN);
86+
case UR_USM_ALLOC_INFO_BASE_PTR:
87+
return ReturnValue(nullptr);
88+
case UR_USM_ALLOC_INFO_SIZE:
89+
return ReturnValue(0);
90+
default:
91+
return UR_RESULT_ERROR_UNKNOWN;
92+
}
93+
}
94+
OL_RETURN_ON_ERR(Err);
95+
96+
if (propName == UR_USM_ALLOC_INFO_TYPE) {
97+
auto *OlType = reinterpret_cast<ol_alloc_type_t *>(pPropValue);
98+
auto *UrType = reinterpret_cast<ur_usm_type_t *>(pPropValue);
99+
switch (*OlType) {
100+
case OL_ALLOC_TYPE_HOST:
101+
*UrType = UR_USM_TYPE_HOST;
102+
break;
103+
case OL_ALLOC_TYPE_DEVICE:
104+
*UrType = UR_USM_TYPE_DEVICE;
105+
break;
106+
case OL_ALLOC_TYPE_MANAGED:
107+
*UrType = UR_USM_TYPE_SHARED;
108+
break;
109+
default:
110+
*UrType = UR_USM_TYPE_UNKNOWN;
111+
break;
112+
}
113+
}
114+
}
115+
116+
return UR_RESULT_SUCCESS;
66117
}

0 commit comments

Comments
 (0)