2020#include <ucs/type/class.h>
2121#include <ucs/memory/memtype_cache.h>
2222
23+ typedef struct {
24+ ucs_memory_type_t src_type ;
25+ ucs_memory_type_t dst_type ;
26+ CUdevice cuda_device ;
27+ CUcontext cuda_context ;
28+ uct_cuda_copy_ctx_rsc_t * ctx_rsc ;
29+ } uct_cuda_copy_ep_rma_ctx_t ;
2330
2431static UCS_CLASS_INIT_FUNC (uct_cuda_copy_ep_t , const uct_ep_params_t * params )
2532{
@@ -65,7 +72,7 @@ uct_cuda_copy_get_stream(uct_cuda_copy_ctx_rsc_t *ctx_rsc,
6572}
6673
6774static UCS_F_ALWAYS_INLINE ucs_memory_type_t
68- uct_cuda_copy_get_mem_type (uct_md_h md , void * address , size_t length ,
75+ uct_cuda_copy_get_mem_type (uct_md_h md , const void * address , size_t length ,
6976 ucs_sys_device_t * sys_dev )
7077{
7178 ucs_memory_info_t mem_info ;
@@ -100,17 +107,23 @@ uct_cuda_copy_get_mem_type(uct_md_h md, void *address, size_t length,
100107}
101108
102109static UCS_F_ALWAYS_INLINE void
103- uct_cuda_copy_get_mem_types (uct_md_h md , void * src , void * dst , size_t length ,
104- ucs_memory_type_t * src_mem_type_p ,
110+ uct_cuda_copy_get_mem_types (uct_md_h md , const void * src , const void * dst ,
111+ size_t length , ucs_memory_type_t * src_mem_type_p ,
105112 ucs_memory_type_t * dst_mem_type_p ,
106- ucs_sys_device_t * sys_dev_p )
113+ ucs_sys_device_t * sys_dev_p ,
114+ CUdeviceptr * cuda_deviceptr_p )
107115{
108116 ucs_sys_device_t src_sys_dev , dst_sys_dev ;
109117
110118 * src_mem_type_p = uct_cuda_copy_get_mem_type (md , src , length , & src_sys_dev );
111119 * dst_mem_type_p = uct_cuda_copy_get_mem_type (md , dst , length , & dst_sys_dev );
112- * sys_dev_p = (src_sys_dev != UCS_SYS_DEVICE_ID_UNKNOWN ) ?
113- src_sys_dev : dst_sys_dev ;
120+ if (src_sys_dev != UCS_SYS_DEVICE_ID_UNKNOWN ) {
121+ * sys_dev_p = src_sys_dev ;
122+ * cuda_deviceptr_p = (CUdeviceptr )src ;
123+ } else {
124+ * sys_dev_p = dst_sys_dev ;
125+ * cuda_deviceptr_p = (CUdeviceptr )dst ;
126+ }
114127
115128 ucs_assertv ((src_sys_dev == dst_sys_dev ) ||
116129 (src_sys_dev == UCS_SYS_DEVICE_ID_UNKNOWN ) ||
@@ -163,21 +176,24 @@ uct_cuda_primary_ctx_push_first_active(CUdevice *cuda_device_p)
163176}
164177
165178static UCS_F_ALWAYS_INLINE void
166- uct_cuda_primary_ctx_pop_and_release (CUdevice cuda_device )
179+ uct_cuda_primary_ctx_pop_and_release (CUdevice cuda_device ,
180+ CUcontext cuda_context )
167181{
168- if (ucs_likely (cuda_device == CU_DEVICE_INVALID )) {
169- return ;
182+ if ((cuda_device != CU_DEVICE_INVALID ) || ( cuda_context != NULL )) {
183+ UCT_CUDADRV_FUNC_LOG_WARN ( cuCtxPopCurrent ( NULL )) ;
170184 }
171185
172- UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
173- UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (cuda_device ));
186+ if (cuda_device != CU_DEVICE_INVALID ) {
187+ UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (cuda_device ));
188+ }
174189}
175190
176- static UCS_F_ALWAYS_INLINE ucs_status_t
177- uct_cuda_copy_ctx_rsc_get ( uct_cuda_copy_iface_t * iface ,
178- ucs_sys_device_t sys_dev , CUdevice * cuda_device_p ,
179- uct_cuda_copy_ctx_rsc_t * * ctx_rsc_p )
191+ static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ctx_rsc_get (
192+ uct_cuda_copy_iface_t * iface , ucs_sys_device_t sys_dev ,
193+ CUdeviceptr cuda_deviceptr , CUdevice * cuda_device_p ,
194+ CUcontext * cuda_context_p , uct_cuda_copy_ctx_rsc_t * * ctx_rsc_p )
180195{
196+ CUcontext cuda_context = NULL ;
181197 unsigned long long ctx_id ;
182198 CUresult result ;
183199 CUdevice cuda_device ;
@@ -195,7 +211,35 @@ uct_cuda_copy_ctx_rsc_get(uct_cuda_copy_iface_t *iface,
195211 }
196212
197213 status = uct_cuda_copy_push_ctx (cuda_device , 0 , UCS_LOG_LEVEL_ERROR );
198- if (status != UCS_OK ) {
214+ if (ucs_unlikely (status == UCS_ERR_NO_DEVICE )) {
215+ /* Device primary context of `cuda_device` is inactive. The memory
216+ * was probably allocated on the context created with cuCtxCreate.
217+ * Fallback to query context based on memory address. */
218+ status = UCT_CUDADRV_FUNC_LOG_ERR (
219+ cuPointerGetAttribute (& cuda_context ,
220+ CU_POINTER_ATTRIBUTE_CONTEXT ,
221+ cuda_deviceptr ));
222+ if (status != UCS_OK ) {
223+ goto err ;
224+ }
225+
226+ if (cuda_context == NULL ) {
227+ ucs_error ("failed to query cuda context for %p allocated on "
228+ "GPU%d" ,
229+ (void * )cuda_deviceptr , cuda_device );
230+ status = UCS_ERR_UNSUPPORTED ;
231+ goto err ;
232+ }
233+
234+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (cuda_context ));
235+ if (status != UCS_OK ) {
236+ goto err ;
237+ }
238+
239+ cuda_device = CU_DEVICE_INVALID ;
240+ }
241+
242+ if (ucs_unlikely (status != UCS_OK )) {
199243 goto err ;
200244 }
201245 } else {
@@ -236,57 +280,83 @@ uct_cuda_copy_ctx_rsc_get(uct_cuda_copy_iface_t *iface,
236280 goto err_pop_and_release ;
237281 }
238282
239- * cuda_device_p = cuda_device ;
240- * ctx_rsc_p = ucs_derived_of (ctx_rsc , uct_cuda_copy_ctx_rsc_t );
283+ * cuda_device_p = cuda_device ;
284+ * cuda_context_p = cuda_context ;
285+ * ctx_rsc_p = ucs_derived_of (ctx_rsc , uct_cuda_copy_ctx_rsc_t );
241286 return UCS_OK ;
242287
243288err_pop_and_release :
244- uct_cuda_primary_ctx_pop_and_release (cuda_device );
289+ uct_cuda_primary_ctx_pop_and_release (cuda_device , cuda_context );
245290err :
246291 return status ;
247292}
248293
294+ static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ep_get_rma_ctx (
295+ uct_cuda_copy_iface_t * iface , const void * src , const void * dst ,
296+ size_t length , uct_cuda_copy_ep_rma_ctx_t * rma_ctx_p )
297+ {
298+ ucs_memory_type_t src_type ;
299+ ucs_memory_type_t dst_type ;
300+ ucs_sys_device_t sys_dev ;
301+ CUdeviceptr cuda_deviceptr ;
302+ CUdevice cuda_device ;
303+ CUcontext cuda_context ;
304+ uct_cuda_copy_ctx_rsc_t * ctx_rsc ;
305+ ucs_status_t status ;
306+
307+ uct_cuda_copy_get_mem_types (iface -> super .super .md , src , dst , length ,
308+ & src_type , & dst_type , & sys_dev ,
309+ & cuda_deviceptr );
310+
311+ status = uct_cuda_copy_ctx_rsc_get (iface , sys_dev , cuda_deviceptr ,
312+ & cuda_device , & cuda_context , & ctx_rsc );
313+ if (ucs_unlikely (status != UCS_OK )) {
314+ return status ;
315+ }
316+
317+ rma_ctx_p -> src_type = src_type ;
318+ rma_ctx_p -> dst_type = dst_type ;
319+ rma_ctx_p -> cuda_device = cuda_device ;
320+ rma_ctx_p -> cuda_context = cuda_context ;
321+ rma_ctx_p -> ctx_rsc = ctx_rsc ;
322+ return UCS_OK ;
323+ }
324+
249325static UCS_F_ALWAYS_INLINE ucs_status_t
250326uct_cuda_copy_post_cuda_async_copy (uct_ep_h tl_ep , void * dst , void * src ,
251327 size_t length , uct_completion_t * comp )
252328{
253- uct_cuda_copy_iface_t * iface = ucs_derived_of (tl_ep -> iface , uct_cuda_copy_iface_t );
254- uct_base_iface_t * base_iface = ucs_derived_of (tl_ep -> iface , uct_base_iface_t );
255- CUdevice cuda_device ;
256- uct_cuda_event_desc_t * cuda_event ;
257- uct_cuda_queue_desc_t * q_desc ;
329+ uct_cuda_copy_iface_t * iface = ucs_derived_of (tl_ep -> iface ,
330+ uct_cuda_copy_iface_t );
331+ uct_cuda_copy_ep_rma_ctx_t rma_ctx ;
258332 ucs_status_t status ;
259- ucs_memory_type_t src_type ;
260- ucs_memory_type_t dst_type ;
261- CUstream * stream ;
333+ uct_cuda_queue_desc_t * q_desc ;
262334 ucs_queue_head_t * event_q ;
263- uct_cuda_copy_ctx_rsc_t * ctx_rsc ;
264- ucs_sys_device_t sys_dev ;
335+ CUstream * stream ;
336+ uct_cuda_event_desc_t * cuda_event ;
265337
266338 if (!length ) {
267339 return UCS_OK ;
268340 }
269341
270- uct_cuda_copy_get_mem_types (base_iface -> md , src , dst , length , & src_type ,
271- & dst_type , & sys_dev );
272-
273- status = uct_cuda_copy_ctx_rsc_get (iface , sys_dev , & cuda_device , & ctx_rsc );
342+ status = uct_cuda_copy_ep_get_rma_ctx (iface , src , dst , length , & rma_ctx );
274343 if (ucs_unlikely (status != UCS_OK )) {
275344 goto out ;
276345 }
277346
278- q_desc = & ctx_rsc -> queue_desc [src_type ][dst_type ];
279- event_q = & q_desc -> event_queue ;
280- stream = uct_cuda_copy_get_stream (ctx_rsc , src_type , dst_type );
347+ q_desc = & rma_ctx .ctx_rsc -> queue_desc [rma_ctx .src_type ][rma_ctx .dst_type ];
348+ event_q = & q_desc -> event_queue ;
349+ stream = uct_cuda_copy_get_stream (rma_ctx .ctx_rsc , rma_ctx .src_type ,
350+ rma_ctx .dst_type );
281351 if (ucs_unlikely (stream == NULL )) {
282352 ucs_error ("stream for src %s dst %s not available" ,
283- ucs_memory_type_names [src_type ],
284- ucs_memory_type_names [dst_type ]);
353+ ucs_memory_type_names [rma_ctx . src_type ],
354+ ucs_memory_type_names [rma_ctx . dst_type ]);
285355 status = UCS_ERR_IO_ERROR ;
286356 goto out_pop_and_release ;
287357 }
288358
289- cuda_event = ucs_mpool_get (& ctx_rsc -> super .event_mp );
359+ cuda_event = ucs_mpool_get (& rma_ctx . ctx_rsc -> super .event_mp );
290360 if (ucs_unlikely (cuda_event == NULL )) {
291361 ucs_error ("failed to allocate cuda event object" );
292362 status = UCS_ERR_NO_MEMORY ;
@@ -313,15 +383,17 @@ uct_cuda_copy_post_cuda_async_copy(uct_ep_h tl_ep, void *dst, void *src,
313383 cuda_event -> comp = comp ;
314384
315385 UCS_STATIC_BITMAP_SET (& iface -> streams_to_sync ,
316- uct_cuda_copy_flush_bitmap_idx (src_type , dst_type ));
386+ uct_cuda_copy_flush_bitmap_idx (rma_ctx .src_type ,
387+ rma_ctx .dst_type ));
317388
318389 ucs_trace ("cuda async issued: %p dst:%p[%s], src:%p[%s] len:%ld" ,
319- cuda_event , dst , ucs_memory_type_names [dst_type ], src ,
320- ucs_memory_type_names [src_type ], length );
390+ cuda_event , dst , ucs_memory_type_names [rma_ctx . dst_type ], src ,
391+ ucs_memory_type_names [rma_ctx . src_type ], length );
321392 status = UCS_INPROGRESS ;
322393
323394out_pop_and_release :
324- uct_cuda_primary_ctx_pop_and_release (cuda_device );
395+ uct_cuda_primary_ctx_pop_and_release (rma_ctx .cuda_device ,
396+ rma_ctx .cuda_context );
325397out :
326398 return status ;
327399err_mpool_put :
@@ -368,31 +440,24 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_copy_ep_put_zcopy,
368440 uct_iov_total_length (iov , iovcnt ));
369441 uct_cuda_copy_trace_data ("PUT_ZCOPY" , remote_addr , iov , iovcnt );
370442 return status ;
371-
372443}
373444
374445static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ep_rma_short (
375446 uct_ep_h tl_ep , CUdeviceptr dst , CUdeviceptr src , unsigned length )
376447{
377448 uct_cuda_copy_iface_t * iface = ucs_derived_of (tl_ep -> iface ,
378449 uct_cuda_copy_iface_t );
379- CUdevice cuda_device ;
380- uct_cuda_copy_ctx_rsc_t * ctx_rsc ;
450+ uct_cuda_copy_ep_rma_ctx_t rma_ctx ;
381451 ucs_status_t status ;
382- ucs_memory_type_t src_type ;
383- ucs_memory_type_t dst_type ;
384- ucs_sys_device_t sys_dev ;
385452 CUstream * stream ;
386453
387- uct_cuda_copy_get_mem_types (iface -> super .super .md , (void * )src , (void * )dst ,
388- length , & src_type , & dst_type , & sys_dev );
389-
390- status = uct_cuda_copy_ctx_rsc_get (iface , sys_dev , & cuda_device , & ctx_rsc );
454+ status = uct_cuda_copy_ep_get_rma_ctx (iface , (void * )src , (void * )dst , length ,
455+ & rma_ctx );
391456 if (ucs_unlikely (status != UCS_OK )) {
392457 goto out ;
393458 }
394459
395- stream = & ctx_rsc -> short_stream ;
460+ stream = & rma_ctx . ctx_rsc -> short_stream ;
396461 status = uct_cuda_base_init_stream (stream );
397462 if (ucs_unlikely (status != UCS_OK )) {
398463 goto out_pop_and_release ;
@@ -406,7 +471,8 @@ static UCS_F_ALWAYS_INLINE ucs_status_t uct_cuda_copy_ep_rma_short(
406471 status = UCT_CUDADRV_FUNC_LOG_ERR (cuStreamSynchronize (* stream ));
407472
408473out_pop_and_release :
409- uct_cuda_primary_ctx_pop_and_release (cuda_device );
474+ uct_cuda_primary_ctx_pop_and_release (rma_ctx .cuda_device ,
475+ rma_ctx .cuda_context );
410476out :
411477 return status ;
412478}
0 commit comments