Skip to content

Commit 9878cfa

Browse files
authored
Merge pull request #10671 from rakhmets/topic/cuda-ipc-rma
UCT/CUDA/CUDA_IPC: Set context associated with local buffer.
2 parents bc5f8e6 + 1684145 commit 9878cfa

File tree

8 files changed: 115 additions and 155 deletions.

src/uct/cuda/Makefile.am

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ noinst_HEADERS = \
2424
cuda_ipc/cuda_ipc_md.h \
2525
cuda_ipc/cuda_ipc_iface.h \
2626
cuda_ipc/cuda_ipc_ep.h \
27-
cuda_ipc/cuda_ipc_cache.h
27+
cuda_ipc/cuda_ipc_cache.h \
28+
cuda_ipc/cuda_ipc.inl
2829

2930
libuct_cuda_la_SOURCES = \
3031
base/cuda_iface.c \

src/uct/cuda/cuda_ipc/cuda_ipc.inl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/**
2+
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2025. ALL RIGHTS RESERVED.
3+
* See file LICENSE for terms.
4+
*/
5+
6+
#ifndef UCT_CUDA_IPC_INL
7+
#define UCT_CUDA_IPC_INL
8+
9+
#include <uct/cuda/base/cuda_iface.h>
10+
11+
#include <cuda.h>
12+
13+
static UCS_F_ALWAYS_INLINE ucs_status_t
14+
uct_cuda_ipc_check_and_push_ctx(CUdeviceptr address, CUdevice *cuda_device_p,
15+
int *is_ctx_pushed)
16+
{
17+
#define UCT_CUDA_IPC_NUM_ATTRS 2
18+
CUpointer_attribute attr_type[UCT_CUDA_IPC_NUM_ATTRS];
19+
void *attr_data[UCT_CUDA_IPC_NUM_ATTRS];
20+
CUcontext cuda_ctx, cuda_ctx_current;
21+
int cuda_device_ordinal;
22+
ucs_status_t status;
23+
CUdevice cuda_device;
24+
25+
attr_type[0] = CU_POINTER_ATTRIBUTE_CONTEXT;
26+
attr_data[0] = &cuda_ctx;
27+
attr_type[1] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
28+
attr_data[1] = &cuda_device_ordinal;
29+
30+
status = UCT_CUDADRV_FUNC_LOG_ERR(
31+
cuPointerGetAttributes(UCT_CUDA_IPC_NUM_ATTRS, attr_type, attr_data,
32+
address));
33+
if (ucs_unlikely(status != UCS_OK)) {
34+
return status;
35+
}
36+
37+
ucs_assertv(cuda_device_ordinal >= 0, "cuda_device_ordinal=%d",
38+
cuda_device_ordinal);
39+
40+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuDeviceGet(&cuda_device,
41+
cuda_device_ordinal));
42+
if (ucs_unlikely(status != UCS_OK)) {
43+
return status;
44+
}
45+
46+
if (cuda_ctx == NULL) {
47+
status = uct_cuda_primary_ctx_retain(cuda_device, 0, &cuda_ctx);
48+
if (ucs_unlikely(status != UCS_OK)) {
49+
return status;
50+
}
51+
52+
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
53+
}
54+
55+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxGetCurrent(&cuda_ctx_current));
56+
if (ucs_unlikely(status != UCS_OK)) {
57+
return status;
58+
}
59+
60+
if (cuda_ctx != cuda_ctx_current) {
61+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxPushCurrent(cuda_ctx));
62+
if (ucs_unlikely(status != UCS_OK)) {
63+
return status;
64+
}
65+
66+
*is_ctx_pushed = 1;
67+
} else {
68+
*is_ctx_pushed = 0;
69+
}
70+
71+
*cuda_device_p = cuda_device;
72+
return UCS_OK;
73+
}
74+
75+
static UCS_F_ALWAYS_INLINE void
76+
uct_cuda_ipc_check_and_pop_ctx(int is_ctx_pushed)
77+
{
78+
if (is_ctx_pushed) {
79+
UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL));
80+
}
81+
}
82+
83+
#endif

src/uct/cuda/cuda_ipc/cuda_ipc_cache.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,14 @@ uct_cuda_ipc_close_memhandle_legacy(uct_cuda_ipc_cache_region_t *region)
141141
{
142142
ucs_status_t status;
143143

144-
status = uct_cuda_ipc_primary_ctx_retain_and_push(region->key.dev_num);
144+
status = uct_cuda_ipc_primary_ctx_retain_and_push(region->cu_dev);
145145
if (status != UCS_OK) {
146146
return status;
147147
}
148148

149149
status = UCT_CUDADRV_FUNC_LOG_WARN(
150150
cuIpcCloseMemHandle((CUdeviceptr)region->mapped_addr));
151-
uct_cuda_ipc_primary_ctx_pop_and_release(region->key.dev_num);
151+
uct_cuda_ipc_primary_ctx_pop_and_release(region->cu_dev);
152152
return status;
153153
}
154154

@@ -626,6 +626,7 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_ipc_map_memhandle,
626626
region->key = *key;
627627
region->mapped_addr = *mapped_addr;
628628
region->refcount = 1;
629+
region->cu_dev = cu_dev;
629630

630631
status = UCS_PROFILE_CALL(ucs_pgtable_insert,
631632
&cache->pgtable, &region->super);

src/uct/cuda/cuda_ipc/cuda_ipc_cache.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct uct_cuda_ipc_cache_region {
2626
uct_cuda_ipc_rkey_t key; /**< Remote memory key */
2727
void *mapped_addr; /**< Local mapped address */
2828
uint64_t refcount; /**< Track in-flight ops before unmapping*/
29+
CUdevice cu_dev; /**< CUDA device */
2930
};
3031

3132

src/uct/cuda/cuda_ipc/cuda_ipc_ep.c

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "cuda_ipc_ep.h"
1111
#include "cuda_ipc_iface.h"
1212
#include "cuda_ipc_md.h"
13+
#include "cuda_ipc.inl"
1314

1415
#include <uct/base/uct_log.h>
1516
#include <uct/base/uct_iov.inl>
@@ -64,49 +65,27 @@ uct_cuda_primary_ctx_pop_and_release(CUdevice cuda_device)
6465
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
6566
}
6667

67-
static UCS_F_ALWAYS_INLINE ucs_status_t
68-
uct_cuda_ipc_ctx_rsc_get(uct_cuda_ipc_iface_t *iface, CUdevice cuda_device,
69-
uct_cuda_ipc_ctx_rsc_t **ctx_rsc_p)
68+
static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_ipc_ctx_rsc_get(
69+
uct_cuda_ipc_iface_t *iface, uct_cuda_ipc_ctx_rsc_t **ctx_rsc_p)
7070
{
7171
unsigned long long ctx_id;
7272
ucs_status_t status;
7373
CUresult result;
74-
CUcontext cuda_ctx;
7574
uct_cuda_ctx_rsc_t *ctx_rsc;
7675

77-
status = uct_cuda_primary_ctx_retain(cuda_device, 0, &cuda_ctx);
78-
if (ucs_unlikely(status != UCS_OK)) {
79-
goto err;
80-
}
81-
82-
status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxPushCurrent(cuda_ctx));
83-
if (ucs_unlikely(status != UCS_OK)) {
84-
/* To workaround gcc 4.8.5 compiler error */
85-
status = UCS_ERR_IO_ERROR;
86-
goto err_release;
87-
}
88-
8976
result = uct_cuda_base_ctx_get_id(NULL, &ctx_id);
9077
if (ucs_unlikely(result != CUDA_SUCCESS)) {
9178
UCT_CUDADRV_LOG(cuCtxGetId, UCS_LOG_LEVEL_ERROR, result);
92-
status = UCS_ERR_IO_ERROR;
93-
goto err_pop;
79+
return UCS_ERR_IO_ERROR;
9480
}
9581

9682
status = uct_cuda_base_ctx_rsc_get(&iface->super, ctx_id, &ctx_rsc);
9783
if (ucs_unlikely(status != UCS_OK)) {
98-
goto err_pop;
84+
return status;
9985
}
10086

10187
*ctx_rsc_p = ucs_derived_of(ctx_rsc, uct_cuda_ipc_ctx_rsc_t);
10288
return UCS_OK;
103-
104-
err_pop:
105-
UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL));
106-
err_release:
107-
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
108-
err:
109-
return status;
11089
}
11190

11291
static UCS_F_ALWAYS_INLINE ucs_status_t
@@ -117,6 +96,8 @@ uct_cuda_ipc_post_cuda_async_copy(uct_ep_h tl_ep, uint64_t remote_addr,
11796
uct_cuda_ipc_iface_t *iface = ucs_derived_of(tl_ep->iface,
11897
uct_cuda_ipc_iface_t);
11998
uct_cuda_ipc_unpacked_rkey_t *key = (uct_cuda_ipc_unpacked_rkey_t *)rkey;
99+
CUdevice cuda_device;
100+
int is_ctx_pushed;
120101
void *mapped_rem_addr;
121102
void *mapped_addr;
122103
uct_cuda_ipc_event_desc_t *cuda_ipc_event;
@@ -133,15 +114,20 @@ uct_cuda_ipc_post_cuda_async_copy(uct_ep_h tl_ep, uint64_t remote_addr,
133114
return UCS_OK;
134115
}
135116

136-
status = uct_cuda_ipc_map_memhandle(&key->super, key->super.dev_num,
137-
&mapped_addr);
117+
status = uct_cuda_ipc_check_and_push_ctx((CUdeviceptr)iov[0].buffer,
118+
&cuda_device, &is_ctx_pushed);
138119
if (ucs_unlikely(status != UCS_OK)) {
139120
return status;
140121
}
141122

142-
status = uct_cuda_ipc_ctx_rsc_get(iface, key->super.dev_num, &ctx_rsc);
123+
status = uct_cuda_ipc_map_memhandle(&key->super, cuda_device, &mapped_addr);
143124
if (ucs_unlikely(status != UCS_OK)) {
144-
return status;
125+
goto out;
126+
}
127+
128+
status = uct_cuda_ipc_ctx_rsc_get(iface, &ctx_rsc);
129+
if (ucs_unlikely(status != UCS_OK)) {
130+
goto out;
145131
}
146132

147133
offset = (uintptr_t)remote_addr - (uintptr_t)key->super.d_bptr;
@@ -158,7 +144,7 @@ uct_cuda_ipc_post_cuda_async_copy(uct_ep_h tl_ep, uint64_t remote_addr,
158144

159145
if (ucs_unlikely(stream == NULL)) {
160146
ucs_error("stream=%d for dev_num=%d not available", key->stream_id,
161-
key->super.dev_num);
147+
cuda_device);
162148
status = UCS_ERR_IO_ERROR;
163149
goto out;
164150
}
@@ -198,13 +184,13 @@ uct_cuda_ipc_post_cuda_async_copy(uct_ep_h tl_ep, uint64_t remote_addr,
198184
cuda_ipc_event->mapped_addr = mapped_addr;
199185
cuda_ipc_event->d_bptr = (uintptr_t)key->super.d_bptr;
200186
cuda_ipc_event->pid = key->super.pid;
201-
cuda_ipc_event->cuda_device = key->super.dev_num;
187+
cuda_ipc_event->cuda_device = cuda_device;
202188
ucs_trace("cuMemcpyDtoDAsync issued :%p dst:%p, src:%p len:%ld",
203189
cuda_ipc_event, (void *) dst, (void *) src, iov[0].length);
204190
status = UCS_INPROGRESS;
205191

206192
out:
207-
uct_cuda_primary_ctx_pop_and_release(key->super.dev_num);
193+
uct_cuda_ipc_check_and_pop_ctx(is_ctx_pushed);
208194
return status;
209195
}
210196

src/uct/cuda/cuda_ipc/cuda_ipc_md.c

Lines changed: 7 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
#include "cuda_ipc_md.h"
1111
#include "cuda_ipc_cache.h"
12-
12+
#include "cuda_ipc.inl"
1313
#include <string.h>
1414
#include <limits.h>
1515
#include <ucs/debug/log.h>
@@ -105,87 +105,13 @@ uct_cuda_ipc_md_query(uct_md_h md, uct_md_attr_v2_t *md_attr)
105105
return UCS_OK;
106106
}
107107

108-
static ucs_status_t
109-
uct_cuda_ipc_mem_reg_push_ctx(CUdeviceptr address, CUdevice *cuda_device_p,
110-
int *is_ctx_pushed, int *is_ctx_retained)
111-
{
112-
#define UCT_CUDA_IPC_NUM_ATTRS 2
113-
CUcontext cuda_curr_ctx, cuda_ctx;
114-
CUdevice cuda_device;
115-
CUpointer_attribute attr_type[UCT_CUDA_IPC_NUM_ATTRS];
116-
void *attr_data[UCT_CUDA_IPC_NUM_ATTRS];
117-
int cuda_device_ordinal;
118-
ucs_status_t status;
119-
120-
attr_type[0] = CU_POINTER_ATTRIBUTE_CONTEXT;
121-
attr_data[0] = &cuda_ctx;
122-
attr_type[1] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
123-
attr_data[1] = &cuda_device_ordinal;
124-
125-
status = UCT_CUDADRV_FUNC_LOG_ERR(
126-
cuPointerGetAttributes(UCT_CUDA_IPC_NUM_ATTRS, attr_type, attr_data,
127-
address));
128-
if (status != UCS_OK) {
129-
return status;
130-
}
131-
132-
ucs_assertv(cuda_device_ordinal >= 0, "cuda_device_ordinal=%d",
133-
cuda_device_ordinal);
134-
135-
status = UCT_CUDADRV_FUNC_LOG_ERR(cuDeviceGet(&cuda_device,
136-
cuda_device_ordinal));
137-
if (status != UCS_OK) {
138-
return status;
139-
}
140-
141-
*is_ctx_pushed = 0;
142-
*cuda_device_p = cuda_device;
143-
144-
if (cuda_ctx == NULL) {
145-
status = uct_cuda_primary_ctx_retain(*cuda_device_p, 0, &cuda_ctx);
146-
if (status != UCS_OK) {
147-
return status;
148-
}
149-
150-
*is_ctx_retained = 1;
151-
} else {
152-
*is_ctx_retained = 0;
153-
status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxGetCurrent(&cuda_curr_ctx));
154-
if ((status != UCS_OK) || (cuda_curr_ctx == cuda_ctx)) {
155-
/* Failed to get current context or the pointer's context is
156-
* already current, no need to push/pop */
157-
return status;
158-
}
159-
}
160-
161-
status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxPushCurrent(cuda_ctx));
162-
if (status != UCS_OK) {
163-
if (*is_ctx_retained) {
164-
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(*cuda_device_p));
165-
}
166-
return status;
167-
}
168-
169-
*is_ctx_pushed = 1;
170-
return UCS_OK;
171-
}
172-
173-
static void uct_cuda_ipc_mem_reg_pop_ctx(CUdevice cuda_device,
174-
int is_ctx_retained)
175-
{
176-
UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL));
177-
if (is_ctx_retained) {
178-
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
179-
}
180-
}
181-
182108
static ucs_status_t
183109
uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
184110
uct_cuda_ipc_lkey_t **key_p)
185111
{
186112
uct_cuda_ipc_lkey_t *key;
187113
ucs_status_t status;
188-
int is_ctx_pushed, is_ctx_retained;
114+
int is_ctx_pushed;
189115
CUdevice cuda_device;
190116
#if HAVE_CUDA_FABRIC
191117
#define UCT_CUDA_IPC_QUERY_NUM_ATTRS 3
@@ -202,8 +128,8 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
202128
return UCS_ERR_NO_MEMORY;
203129
}
204130

205-
status = uct_cuda_ipc_mem_reg_push_ctx((CUdeviceptr)addr, &cuda_device,
206-
&is_ctx_pushed, &is_ctx_retained);
131+
status = uct_cuda_ipc_check_and_push_ctx((CUdeviceptr)addr, &cuda_device,
132+
&is_ctx_pushed);
207133
if (status != UCS_OK) {
208134
goto out;
209135
}
@@ -313,17 +239,15 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
313239
ucs_list_add_tail(&memh->list, &key->link);
314240
ucs_trace("registered addr:%p/%p length:%zd type:%u dev_num:%d "
315241
"buffer_id:%llu",
316-
addr, (void *)key->d_bptr, key->b_len, key->ph.handle_type,
317-
memh->dev_num, key->ph.buffer_id);
242+
addr, (void*)key->d_bptr, key->b_len, key->ph.handle_type,
243+
cuda_device, key->ph.buffer_id);
318244

319245
memh->dev_num = cuda_device;
320246
*key_p = key;
321247
status = UCS_OK;
322248

323249
out_pop_ctx:
324-
if (is_ctx_pushed) {
325-
uct_cuda_ipc_mem_reg_pop_ctx(memh->dev_num, is_ctx_retained);
326-
}
250+
uct_cuda_ipc_check_and_pop_ctx(is_ctx_pushed);
327251
out:
328252
if (status != UCS_OK) {
329253
ucs_free(key);
@@ -394,10 +318,6 @@ uct_cuda_ipc_is_peer_accessible(uct_cuda_ipc_component_t *component,
394318
}
395319
}
396320

397-
/* Save local device number, so we use it to find remote rcache when mapping
398-
* mem_handle in uct_cuda_ipc_post_cuda_async_copy */
399-
rkey->super.dev_num = cu_dev;
400-
401321
pthread_mutex_lock(&component->lock);
402322

403323
cache = uct_cuda_ipc_get_dev_cache(component, &rkey->super);

src/uct/cuda/cuda_ipc/cuda_ipc_md.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ typedef struct {
135135
pid_t pid; /* PID as key to resolve peer_map hash */
136136
CUdeviceptr d_bptr; /* Allocation base address */
137137
size_t b_len; /* Allocation size */
138-
int dev_num; /* GPU Device number */
139138
CUuuid uuid; /* GPU Device UUID */
140139
} uct_cuda_ipc_rkey_t;
141140

0 commit comments

Comments
 (0)