Skip to content

Commit 4a03d17

Browse files
committed
UCT/CUDA/CUDA_COPY: Fixed sending memory allocated on user context.
1 parent 0210ad6 commit 4a03d17

File tree

2 files changed

+183
-60
lines changed

2 files changed

+183
-60
lines changed

src/uct/cuda/cuda_copy/cuda_copy_ep.c

Lines changed: 121 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
#include <ucs/type/class.h>
2121
#include <ucs/memory/memtype_cache.h>
2222

23+
typedef struct {
24+
ucs_memory_type_t src_type;
25+
ucs_memory_type_t dst_type;
26+
CUdevice cuda_device;
27+
CUcontext cuda_context;
28+
uct_cuda_copy_ctx_rsc_t *ctx_rsc;
29+
} uct_cuda_copy_ep_rma_ctx_t;
2330

2431
static UCS_CLASS_INIT_FUNC(uct_cuda_copy_ep_t, const uct_ep_params_t *params)
2532
{
@@ -65,7 +72,7 @@ uct_cuda_copy_get_stream(uct_cuda_copy_ctx_rsc_t *ctx_rsc,
6572
}
6673

6774
static UCS_F_ALWAYS_INLINE ucs_memory_type_t
68-
uct_cuda_copy_get_mem_type(uct_md_h md, void *address, size_t length,
75+
uct_cuda_copy_get_mem_type(uct_md_h md, const void *address, size_t length,
6976
ucs_sys_device_t *sys_dev)
7077
{
7178
ucs_memory_info_t mem_info;
@@ -100,17 +107,23 @@ uct_cuda_copy_get_mem_type(uct_md_h md, void *address, size_t length,
100107
}
101108

102109
static UCS_F_ALWAYS_INLINE void
103-
uct_cuda_copy_get_mem_types(uct_md_h md, void *src, void *dst, size_t length,
104-
ucs_memory_type_t *src_mem_type_p,
110+
uct_cuda_copy_get_mem_types(uct_md_h md, const void *src, const void *dst,
111+
size_t length, ucs_memory_type_t *src_mem_type_p,
105112
ucs_memory_type_t *dst_mem_type_p,
106-
ucs_sys_device_t *sys_dev_p)
113+
ucs_sys_device_t *sys_dev_p,
114+
CUdeviceptr *cuda_deviceptr_p)
107115
{
108116
ucs_sys_device_t src_sys_dev, dst_sys_dev;
109117

110118
*src_mem_type_p = uct_cuda_copy_get_mem_type(md, src, length, &src_sys_dev);
111119
*dst_mem_type_p = uct_cuda_copy_get_mem_type(md, dst, length, &dst_sys_dev);
112-
*sys_dev_p = (src_sys_dev != UCS_SYS_DEVICE_ID_UNKNOWN) ?
113-
src_sys_dev : dst_sys_dev;
120+
if (src_sys_dev != UCS_SYS_DEVICE_ID_UNKNOWN) {
121+
*sys_dev_p = src_sys_dev;
122+
*cuda_deviceptr_p = (CUdeviceptr)src;
123+
} else {
124+
*sys_dev_p = dst_sys_dev;
125+
*cuda_deviceptr_p = (CUdeviceptr)dst;
126+
}
114127

115128
ucs_assertv((src_sys_dev == dst_sys_dev) ||
116129
(src_sys_dev == UCS_SYS_DEVICE_ID_UNKNOWN) ||
@@ -163,21 +176,24 @@ uct_cuda_primary_ctx_push_first_active(CUdevice *cuda_device_p)
163176
}
164177

165178
static UCS_F_ALWAYS_INLINE void
166-
uct_cuda_primary_ctx_pop_and_release(CUdevice cuda_device)
179+
uct_cuda_primary_ctx_pop_and_release(CUdevice cuda_device,
180+
CUcontext cuda_context)
167181
{
168-
if (ucs_likely(cuda_device == CU_DEVICE_INVALID)) {
169-
return;
182+
if ((cuda_device != CU_DEVICE_INVALID) || (cuda_context != NULL)) {
183+
UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL));
170184
}
171185

172-
UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL));
173-
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
186+
if (cuda_device != CU_DEVICE_INVALID) {
187+
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
188+
}
174189
}
175190

176-
static UCS_F_ALWAYS_INLINE ucs_status_t
177-
uct_cuda_copy_ctx_rsc_get(uct_cuda_copy_iface_t *iface,
178-
ucs_sys_device_t sys_dev, CUdevice *cuda_device_p,
179-
uct_cuda_copy_ctx_rsc_t **ctx_rsc_p)
191+
static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ctx_rsc_get(
192+
uct_cuda_copy_iface_t *iface, ucs_sys_device_t sys_dev,
193+
CUdeviceptr cuda_deviceptr, CUdevice *cuda_device_p,
194+
CUcontext *cuda_context_p, uct_cuda_copy_ctx_rsc_t **ctx_rsc_p)
180195
{
196+
CUcontext cuda_context = NULL;
181197
unsigned long long ctx_id;
182198
CUresult result;
183199
CUdevice cuda_device;
@@ -195,7 +211,35 @@ uct_cuda_copy_ctx_rsc_get(uct_cuda_copy_iface_t *iface,
195211
}
196212

197213
status = uct_cuda_copy_push_ctx(cuda_device, 0, UCS_LOG_LEVEL_ERROR);
198-
if (status != UCS_OK) {
214+
if (ucs_unlikely(status == UCS_ERR_NO_DEVICE)) {
215+
/* Device primary context of `cuda_device` is inactive. The memory
216+
* was probably allocated on the context created with cuCtxCreate.
217+
* Fallback to query context based on memory address. */
218+
status = UCT_CUDADRV_FUNC_LOG_ERR(
219+
cuPointerGetAttribute(&cuda_context,
220+
CU_POINTER_ATTRIBUTE_CONTEXT,
221+
cuda_deviceptr));
222+
if (status != UCS_OK) {
223+
goto err;
224+
}
225+
226+
if (cuda_context == NULL) {
227+
ucs_error("failed to query cuda context for %p allocated on "
228+
"GPU%d",
229+
(void*)cuda_deviceptr, cuda_device);
230+
status = UCS_ERR_UNSUPPORTED;
231+
goto err;
232+
}
233+
234+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxPushCurrent(cuda_context));
235+
if (status != UCS_OK) {
236+
goto err;
237+
}
238+
239+
cuda_device = CU_DEVICE_INVALID;
240+
}
241+
242+
if (ucs_unlikely(status != UCS_OK)) {
199243
goto err;
200244
}
201245
} else {
@@ -236,57 +280,83 @@ uct_cuda_copy_ctx_rsc_get(uct_cuda_copy_iface_t *iface,
236280
goto err_pop_and_release;
237281
}
238282

239-
*cuda_device_p = cuda_device;
240-
*ctx_rsc_p = ucs_derived_of(ctx_rsc, uct_cuda_copy_ctx_rsc_t);
283+
*cuda_device_p = cuda_device;
284+
*cuda_context_p = cuda_context;
285+
*ctx_rsc_p = ucs_derived_of(ctx_rsc, uct_cuda_copy_ctx_rsc_t);
241286
return UCS_OK;
242287

243288
err_pop_and_release:
244-
uct_cuda_primary_ctx_pop_and_release(cuda_device);
289+
uct_cuda_primary_ctx_pop_and_release(cuda_device, cuda_context);
245290
err:
246291
return status;
247292
}
248293

294+
static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ep_get_rma_ctx(
295+
uct_cuda_copy_iface_t *iface, const void *src, const void *dst,
296+
size_t length, uct_cuda_copy_ep_rma_ctx_t *rma_ctx_p)
297+
{
298+
ucs_memory_type_t src_type;
299+
ucs_memory_type_t dst_type;
300+
ucs_sys_device_t sys_dev;
301+
CUdeviceptr cuda_deviceptr;
302+
CUdevice cuda_device;
303+
CUcontext cuda_context;
304+
uct_cuda_copy_ctx_rsc_t *ctx_rsc;
305+
ucs_status_t status;
306+
307+
uct_cuda_copy_get_mem_types(iface->super.super.md, src, dst, length,
308+
&src_type, &dst_type, &sys_dev,
309+
&cuda_deviceptr);
310+
311+
status = uct_cuda_copy_ctx_rsc_get(iface, sys_dev, cuda_deviceptr,
312+
&cuda_device, &cuda_context, &ctx_rsc);
313+
if (ucs_unlikely(status != UCS_OK)) {
314+
return status;
315+
}
316+
317+
rma_ctx_p->src_type = src_type;
318+
rma_ctx_p->dst_type = dst_type;
319+
rma_ctx_p->cuda_device = cuda_device;
320+
rma_ctx_p->cuda_context = cuda_context;
321+
rma_ctx_p->ctx_rsc = ctx_rsc;
322+
return UCS_OK;
323+
}
324+
249325
static UCS_F_ALWAYS_INLINE ucs_status_t
250326
uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
251327
size_t length, uct_completion_t *comp)
252328
{
253-
uct_cuda_copy_iface_t *iface = ucs_derived_of(tl_ep->iface, uct_cuda_copy_iface_t);
254-
uct_base_iface_t *base_iface = ucs_derived_of(tl_ep->iface, uct_base_iface_t);
255-
CUdevice cuda_device;
256-
uct_cuda_event_desc_t *cuda_event;
257-
uct_cuda_queue_desc_t *q_desc;
329+
uct_cuda_copy_iface_t *iface = ucs_derived_of(tl_ep->iface,
330+
uct_cuda_copy_iface_t);
331+
uct_cuda_copy_ep_rma_ctx_t rma_ctx;
258332
ucs_status_t status;
259-
ucs_memory_type_t src_type;
260-
ucs_memory_type_t dst_type;
261-
CUstream *stream;
333+
uct_cuda_queue_desc_t *q_desc;
262334
ucs_queue_head_t *event_q;
263-
uct_cuda_copy_ctx_rsc_t *ctx_rsc;
264-
ucs_sys_device_t sys_dev;
335+
CUstream *stream;
336+
uct_cuda_event_desc_t *cuda_event;
265337

266338
if (!length) {
267339
return UCS_OK;
268340
}
269341

270-
uct_cuda_copy_get_mem_types(base_iface->md, src, dst, length, &src_type,
271-
&dst_type, &sys_dev);
272-
273-
status = uct_cuda_copy_ctx_rsc_get(iface, sys_dev, &cuda_device, &ctx_rsc);
342+
status = uct_cuda_copy_ep_get_rma_ctx(iface, src, dst, length, &rma_ctx);
274343
if (ucs_unlikely(status != UCS_OK)) {
275344
goto out;
276345
}
277346

278-
q_desc = &ctx_rsc->queue_desc[src_type][dst_type];
279-
event_q = &q_desc->event_queue;
280-
stream = uct_cuda_copy_get_stream(ctx_rsc, src_type, dst_type);
347+
q_desc = &rma_ctx.ctx_rsc->queue_desc[rma_ctx.src_type][rma_ctx.dst_type];
348+
event_q = &q_desc->event_queue;
349+
stream = uct_cuda_copy_get_stream(rma_ctx.ctx_rsc, rma_ctx.src_type,
350+
rma_ctx.dst_type);
281351
if (ucs_unlikely(stream == NULL)) {
282352
ucs_error("stream for src %s dst %s not available",
283-
ucs_memory_type_names[src_type],
284-
ucs_memory_type_names[dst_type]);
353+
ucs_memory_type_names[rma_ctx.src_type],
354+
ucs_memory_type_names[rma_ctx.dst_type]);
285355
status = UCS_ERR_IO_ERROR;
286356
goto out_pop_and_release;
287357
}
288358

289-
cuda_event = ucs_mpool_get(&ctx_rsc->super.event_mp);
359+
cuda_event = ucs_mpool_get(&rma_ctx.ctx_rsc->super.event_mp);
290360
if (ucs_unlikely(cuda_event == NULL)) {
291361
ucs_error("failed to allocate cuda event object");
292362
status = UCS_ERR_NO_MEMORY;
@@ -313,15 +383,17 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
313383
cuda_event->comp = comp;
314384

315385
UCS_STATIC_BITMAP_SET(&iface->streams_to_sync,
316-
uct_cuda_copy_flush_bitmap_idx(src_type, dst_type));
386+
uct_cuda_copy_flush_bitmap_idx(rma_ctx.src_type,
387+
rma_ctx.dst_type));
317388

318389
ucs_trace("cuda async issued: %p dst:%p[%s], src:%p[%s] len:%ld",
319-
cuda_event, dst, ucs_memory_type_names[dst_type], src,
320-
ucs_memory_type_names[src_type], length);
390+
cuda_event, dst, ucs_memory_type_names[rma_ctx.dst_type], src,
391+
ucs_memory_type_names[rma_ctx.src_type], length);
321392
status = UCS_INPROGRESS;
322393

323394
out_pop_and_release:
324-
uct_cuda_primary_ctx_pop_and_release(cuda_device);
395+
uct_cuda_primary_ctx_pop_and_release(rma_ctx.cuda_device,
396+
rma_ctx.cuda_context);
325397
out:
326398
return status;
327399
err_mpool_put:
@@ -368,31 +440,24 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_copy_ep_put_zcopy,
368440
uct_iov_total_length(iov, iovcnt));
369441
uct_cuda_copy_trace_data("PUT_ZCOPY", remote_addr, iov, iovcnt);
370442
return status;
371-
372443
}
373444

374445
static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ep_rma_short(
375446
uct_ep_h tl_ep, CUdeviceptr dst, CUdeviceptr src, unsigned length)
376447
{
377448
uct_cuda_copy_iface_t *iface = ucs_derived_of(tl_ep->iface,
378449
uct_cuda_copy_iface_t);
379-
CUdevice cuda_device;
380-
uct_cuda_copy_ctx_rsc_t *ctx_rsc;
450+
uct_cuda_copy_ep_rma_ctx_t rma_ctx;
381451
ucs_status_t status;
382-
ucs_memory_type_t src_type;
383-
ucs_memory_type_t dst_type;
384-
ucs_sys_device_t sys_dev;
385452
CUstream *stream;
386453

387-
uct_cuda_copy_get_mem_types(iface->super.super.md, (void*)src, (void*)dst,
388-
length, &src_type, &dst_type, &sys_dev);
389-
390-
status = uct_cuda_copy_ctx_rsc_get(iface, sys_dev, &cuda_device, &ctx_rsc);
454+
status = uct_cuda_copy_ep_get_rma_ctx(iface, (void*)src, (void*)dst, length,
455+
&rma_ctx);
391456
if (ucs_unlikely(status != UCS_OK)) {
392457
goto out;
393458
}
394459

395-
stream = &ctx_rsc->short_stream;
460+
stream = &rma_ctx.ctx_rsc->short_stream;
396461
status = uct_cuda_base_init_stream(stream);
397462
if (ucs_unlikely(status != UCS_OK)) {
398463
goto out_pop_and_release;
@@ -406,7 +471,8 @@ static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ep_rma_short(
406471
status = UCT_CUDADRV_FUNC_LOG_ERR(cuStreamSynchronize(*stream));
407472

408473
out_pop_and_release:
409-
uct_cuda_primary_ctx_pop_and_release(cuda_device);
474+
uct_cuda_primary_ctx_pop_and_release(rma_ctx.cuda_device,
475+
rma_ctx.cuda_context);
410476
out:
411477
return status;
412478
}

test/gtest/uct/cuda/test_switch_cuda_device.cc

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,66 @@ _UCT_MD_INSTANTIATE_TEST_CASE(test_switch_cuda_device, cuda_cpy);
180180

181181
class test_p2p_create_destroy_ctx : public uct_p2p_rma_test {
182182
public:
183+
void cleanup() override;
183184
void test_xfer(send_func_t send, size_t length, unsigned flags,
184185
ucs_memory_type_t mem_type) override;
186+
187+
private:
188+
CUcontext m_cuda_context = nullptr;
185189
};
186190

191+
void test_p2p_create_destroy_ctx::cleanup()
192+
{
193+
uct_p2p_rma_test::cleanup();
194+
if (m_cuda_context != nullptr) {
195+
EXPECT_EQ(cuCtxDestroy(m_cuda_context), CUDA_SUCCESS);
196+
}
197+
}
198+
199+
namespace {
200+
CUresult deactivate_primary_cuda_context(CUdevice cuda_device)
201+
{
202+
for (;;) {
203+
unsigned ctx_flags;
204+
int active;
205+
auto cuda_result = cuDevicePrimaryCtxGetState(cuda_device, &ctx_flags,
206+
&active);
207+
if (cuda_result != CUDA_SUCCESS) {
208+
return cuda_result;
209+
}
210+
211+
if (active == 0) {
212+
return CUDA_SUCCESS;
213+
}
214+
215+
cuda_result = cuDevicePrimaryCtxRelease(cuda_device);
216+
if (cuda_result != CUDA_SUCCESS) {
217+
return cuda_result;
218+
}
219+
}
220+
} // namespace
221+
222+
CUresult clear_cuda_context_stack()
223+
{
224+
for (;;) {
225+
CUcontext cuda_context;
226+
auto cuda_result = cuCtxGetCurrent(&cuda_context);
227+
if (cuda_result != CUDA_SUCCESS) {
228+
return cuda_result;
229+
}
230+
231+
if (cuda_context == nullptr) {
232+
return CUDA_SUCCESS;
233+
}
234+
235+
cuda_result = cuCtxPopCurrent(NULL);
236+
if (cuda_result != CUDA_SUCCESS) {
237+
return cuda_result;
238+
}
239+
}
240+
}
241+
}
242+
187243
void test_p2p_create_destroy_ctx::test_xfer(send_func_t send, size_t length,
188244
unsigned flags,
189245
ucs_memory_type_t mem_type)
@@ -197,16 +253,17 @@ void test_p2p_create_destroy_ctx::test_xfer(send_func_t send, size_t length,
197253

198254
CUdevice device;
199255
ASSERT_EQ(cuDeviceGet(&device, 0), CUDA_SUCCESS);
256+
ASSERT_EQ(deactivate_primary_cuda_context(device), CUDA_SUCCESS);
257+
ASSERT_EQ(clear_cuda_context_stack(), CUDA_SUCCESS);
200258

201-
CUcontext ctx;
202-
#if CUDA_VERSION >= 12050
259+
#if CUDA_VERSION >= 13000
203260
CUctxCreateParams ctx_create_params = {};
204-
ASSERT_EQ(cuCtxCreate_v4(&ctx, &ctx_create_params, 0, device), CUDA_SUCCESS);
261+
ASSERT_EQ(cuCtxCreate(&m_cuda_context, &ctx_create_params, 0, device),
262+
CUDA_SUCCESS);
205263
#else
206-
ASSERT_EQ(cuCtxCreate(&ctx, 0, device), CUDA_SUCCESS);
264+
ASSERT_EQ(cuCtxCreate(&m_cuda_context, 0, device), CUDA_SUCCESS);
207265
#endif
208266
uct_p2p_rma_test::test_xfer(send, length, flags, mem_type);
209-
EXPECT_EQ(cuCtxDestroy(ctx), CUDA_SUCCESS);
210267
}
211268

212269
UCS_TEST_P(test_p2p_create_destroy_ctx, put_short)

0 commit comments

Comments
 (0)