@@ -53,6 +53,10 @@ typedef struct cu_ops_t {
     CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
     CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
     CUresult (*cuCtxSetCurrent)(CUcontext ctx);
+    CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
+    CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
+                                   unsigned int Flags);
+    CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
 } cu_ops_t;
 
 static cu_ops_t g_cu_ops;
@@ -123,12 +127,20 @@ static void init_cu_global_state(void) {
         utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
     *(void **)&g_cu_ops.cuCtxSetCurrent =
         utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
+    *(void **)&g_cu_ops.cuIpcGetMemHandle =
+        utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
+    *(void **)&g_cu_ops.cuIpcOpenMemHandle =
+        utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
+    *(void **)&g_cu_ops.cuIpcCloseMemHandle =
+        utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
 
     if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
         !g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
         !g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
         !g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
-        !g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
+        !g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
+        !g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
+        !g_cu_ops.cuIpcCloseMemHandle) {
         LOG_ERR("Required CUDA symbols not found.");
         Init_cu_global_state_failed = true;
     }
@@ -396,6 +408,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
     return "CUDA";
 }
 
+typedef CUipcMemHandle cu_ipc_data_t;
+
+static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
+                                                           size_t *size) {
+    if (provider == NULL || size == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    *size = sizeof(cu_ipc_data_t);
+    return UMF_RESULT_SUCCESS;
+}
+
+static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
+                                                      const void *ptr,
+                                                      size_t size,
+                                                      void *providerIpcData) {
+    (void)size;
+
+    if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    CUresult cu_result;
+    cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
+
+    cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
+    if (cu_result != CUDA_SUCCESS) {
+        LOG_ERR("cuIpcGetMemHandle() failed.");
+        return cu2umf_result(cu_result);
+    }
+
+    return UMF_RESULT_SUCCESS;
+}
+
+static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
+                                                      void *providerIpcData) {
+    if (provider == NULL || providerIpcData == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    return UMF_RESULT_SUCCESS;
+}
+
+static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
+                                                       void *providerIpcData,
+                                                       void **ptr) {
+    if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
+
+    CUresult cu_result;
+    cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
+
+    // Remember current context and set the one from the provider
+    CUcontext restore_ctx = NULL;
+    umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
+    if (umf_result != UMF_RESULT_SUCCESS) {
+        return umf_result;
+    }
+
+    cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
+                                            CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
+
+    if (cu_result != CUDA_SUCCESS) {
+        LOG_ERR("cuIpcOpenMemHandle() failed.");
+    }
+
+    set_context(restore_ctx, &restore_ctx);
+
+    return cu2umf_result(cu_result);
+}
+
+static umf_result_t
+cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
+    (void)size;
+
+    if (provider == NULL || ptr == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    CUresult cu_result;
+
+    cu_result = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
+    if (cu_result != CUDA_SUCCESS) {
+        LOG_ERR("cuIpcCloseMemHandle() failed.");
+        return cu2umf_result(cu_result);
+    }
+
+    return UMF_RESULT_SUCCESS;
+}
+
 static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
     .version = UMF_VERSION_CURRENT,
     .initialize = cu_memory_provider_initialize,
@@ -412,12 +517,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
     .ext.purge_force = cu_memory_provider_purge_force,
     .ext.allocation_merge = cu_memory_provider_allocation_merge,
     .ext.allocation_split = cu_memory_provider_allocation_split,
+    */
     .ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
     .ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
     .ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
     .ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
     .ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
-    */
 };
 
 umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
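Note that moving the closing `*/` above the `.ipc.*` entries in the ops table turns those five callbacks from commented-out placeholders into live members of the provider interface. For context, here is a minimal sketch (not part of the patch) of the driver-level flow these callbacks wrap, using only the raw CUDA Driver API. It assumes the driver is initialized and a CUDA context is current in each process, and that the caller ships the CUipcMemHandle bytes between processes over some transport (pipe, socket, shared file); the helper names export_allocation and import_and_release are illustrative only.

#include <cuda.h>

/* Exporting process: allocate device memory and produce an IPC handle for it.
 * The handle is plain bytes and can be copied into any IPC message. */
static CUresult export_allocation(size_t size, CUdeviceptr *dptr,
                                  CUipcMemHandle *handle) {
    CUresult res = cuMemAlloc(dptr, size);
    if (res != CUDA_SUCCESS) {
        return res;
    }
    return cuIpcGetMemHandle(handle, *dptr);
}

/* Importing process: map the exported allocation into this process's address
 * space, use it, then unmap it again. */
static CUresult import_and_release(const CUipcMemHandle *handle) {
    CUdeviceptr mapped = 0;
    CUresult res = cuIpcOpenMemHandle(&mapped, *handle,
                                      CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
    if (res != CUDA_SUCCESS) {
        return res;
    }
    /* ... read or write `mapped` here, e.g. with cuMemcpyDtoH ... */
    return cuIpcCloseMemHandle(mapped);
}

In the patch itself, cu_memory_provider_open_ipc_handle additionally switches to the provider's CUcontext via set_context() before calling cuIpcOpenMemHandle and restores the previous context afterwards, so the imported mapping is created in the context the provider was initialized with.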