9
9
10
10
#include "cuda_ipc_md.h"
11
11
#include "cuda_ipc_cache.h"
12
-
12
+ #include "cuda_ipc.inl"
13
13
#include <string.h>
14
14
#include <limits.h>
15
15
#include <ucs/debug/log.h>
@@ -105,87 +105,13 @@ uct_cuda_ipc_md_query(uct_md_h md, uct_md_attr_v2_t *md_attr)
105
105
return UCS_OK ;
106
106
}
107
107
108
- static ucs_status_t
109
- uct_cuda_ipc_mem_reg_push_ctx (CUdeviceptr address , CUdevice * cuda_device_p ,
110
- int * is_ctx_pushed , int * is_ctx_retained )
111
- {
112
- #define UCT_CUDA_IPC_NUM_ATTRS 2
113
- CUcontext cuda_curr_ctx , cuda_ctx ;
114
- CUdevice cuda_device ;
115
- CUpointer_attribute attr_type [UCT_CUDA_IPC_NUM_ATTRS ];
116
- void * attr_data [UCT_CUDA_IPC_NUM_ATTRS ];
117
- int cuda_device_ordinal ;
118
- ucs_status_t status ;
119
-
120
- attr_type [0 ] = CU_POINTER_ATTRIBUTE_CONTEXT ;
121
- attr_data [0 ] = & cuda_ctx ;
122
- attr_type [1 ] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL ;
123
- attr_data [1 ] = & cuda_device_ordinal ;
124
-
125
- status = UCT_CUDADRV_FUNC_LOG_ERR (
126
- cuPointerGetAttributes (UCT_CUDA_IPC_NUM_ATTRS , attr_type , attr_data ,
127
- address ));
128
- if (status != UCS_OK ) {
129
- return status ;
130
- }
131
-
132
- ucs_assertv (cuda_device_ordinal >= 0 , "cuda_device_ordinal=%d" ,
133
- cuda_device_ordinal );
134
-
135
- status = UCT_CUDADRV_FUNC_LOG_ERR (cuDeviceGet (& cuda_device ,
136
- cuda_device_ordinal ));
137
- if (status != UCS_OK ) {
138
- return status ;
139
- }
140
-
141
- * is_ctx_pushed = 0 ;
142
- * cuda_device_p = cuda_device ;
143
-
144
- if (cuda_ctx == NULL ) {
145
- status = uct_cuda_primary_ctx_retain (* cuda_device_p , 0 , & cuda_ctx );
146
- if (status != UCS_OK ) {
147
- return status ;
148
- }
149
-
150
- * is_ctx_retained = 1 ;
151
- } else {
152
- * is_ctx_retained = 0 ;
153
- status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxGetCurrent (& cuda_curr_ctx ));
154
- if ((status != UCS_OK ) || (cuda_curr_ctx == cuda_ctx )) {
155
- /* Failed to get current context or the pointer's context is
156
- * already current, no need to push/pop */
157
- return status ;
158
- }
159
- }
160
-
161
- status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (cuda_ctx ));
162
- if (status != UCS_OK ) {
163
- if (* is_ctx_retained ) {
164
- UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (* cuda_device_p ));
165
- }
166
- return status ;
167
- }
168
-
169
- * is_ctx_pushed = 1 ;
170
- return UCS_OK ;
171
- }
172
-
173
- static void uct_cuda_ipc_mem_reg_pop_ctx (CUdevice cuda_device ,
174
- int is_ctx_retained )
175
- {
176
- UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
177
- if (is_ctx_retained ) {
178
- UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (cuda_device ));
179
- }
180
- }
181
-
182
108
static ucs_status_t
183
109
uct_cuda_ipc_mem_add_reg (void * addr , uct_cuda_ipc_memh_t * memh ,
184
110
uct_cuda_ipc_lkey_t * * key_p )
185
111
{
186
112
uct_cuda_ipc_lkey_t * key ;
187
113
ucs_status_t status ;
188
- int is_ctx_pushed , is_ctx_retained ;
114
+ int is_ctx_pushed ;
189
115
CUdevice cuda_device ;
190
116
#if HAVE_CUDA_FABRIC
191
117
#define UCT_CUDA_IPC_QUERY_NUM_ATTRS 3
@@ -202,8 +128,8 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
202
128
return UCS_ERR_NO_MEMORY ;
203
129
}
204
130
205
- status = uct_cuda_ipc_mem_reg_push_ctx ((CUdeviceptr )addr , & cuda_device ,
206
- & is_ctx_pushed , & is_ctx_retained );
131
+ status = uct_cuda_ipc_check_and_push_ctx ((CUdeviceptr )addr , & cuda_device ,
132
+ & is_ctx_pushed );
207
133
if (status != UCS_OK ) {
208
134
goto out ;
209
135
}
@@ -313,17 +239,15 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
313
239
ucs_list_add_tail (& memh -> list , & key -> link );
314
240
ucs_trace ("registered addr:%p/%p length:%zd type:%u dev_num:%d "
315
241
"buffer_id:%llu" ,
316
- addr , (void * )key -> d_bptr , key -> b_len , key -> ph .handle_type ,
317
- memh -> dev_num , key -> ph .buffer_id );
242
+ addr , (void * )key -> d_bptr , key -> b_len , key -> ph .handle_type ,
243
+ cuda_device , key -> ph .buffer_id );
318
244
319
245
memh -> dev_num = cuda_device ;
320
246
* key_p = key ;
321
247
status = UCS_OK ;
322
248
323
249
out_pop_ctx :
324
- if (is_ctx_pushed ) {
325
- uct_cuda_ipc_mem_reg_pop_ctx (memh -> dev_num , is_ctx_retained );
326
- }
250
+ uct_cuda_ipc_check_and_pop_ctx (is_ctx_pushed );
327
251
out :
328
252
if (status != UCS_OK ) {
329
253
ucs_free (key );
@@ -394,10 +318,6 @@ uct_cuda_ipc_is_peer_accessible(uct_cuda_ipc_component_t *component,
394
318
}
395
319
}
396
320
397
- /* Save local device number, so we use it to find remote rcache when mapping
398
- * mem_handle in uct_cuda_ipc_post_cuda_async_copy */
399
- rkey -> super .dev_num = cu_dev ;
400
-
401
321
pthread_mutex_lock (& component -> lock );
402
322
403
323
cache = uct_cuda_ipc_get_dev_cache (component , & rkey -> super );
0 commit comments