@@ -53,6 +53,10 @@ typedef struct cu_ops_t {
53
53
CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
54
54
CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
55
55
CUresult (* cuCtxSetCurrent )(CUcontext ctx );
56
+ CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
57
+ CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
58
+ unsigned int Flags );
59
+ CUresult (* cuIpcCloseMemHandle )(CUdeviceptr dptr );
56
60
} cu_ops_t ;
57
61
58
62
static cu_ops_t g_cu_ops ;
@@ -123,12 +127,20 @@ static void init_cu_global_state(void) {
123
127
utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
124
128
* (void * * )& g_cu_ops .cuCtxSetCurrent =
125
129
utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
130
+ * (void * * )& g_cu_ops .cuIpcGetMemHandle =
131
+ utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
132
+ * (void * * )& g_cu_ops .cuIpcOpenMemHandle =
133
+ utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
134
+ * (void * * )& g_cu_ops .cuIpcCloseMemHandle =
135
+ utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
126
136
127
137
if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
128
138
!g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
129
139
!g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
130
140
!g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
131
- !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ) {
141
+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
142
+ !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
143
+ !g_cu_ops .cuIpcCloseMemHandle ) {
132
144
LOG_ERR ("Required CUDA symbols not found." );
133
145
Init_cu_global_state_failed = true;
134
146
}
@@ -404,6 +416,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
404
416
return "CUDA" ;
405
417
}
406
418
419
+ typedef CUipcMemHandle cu_ipc_data_t ;
420
+
421
+ static umf_result_t cu_memory_provider_get_ipc_handle_size (void * provider ,
422
+ size_t * size ) {
423
+ if (provider == NULL || size == NULL ) {
424
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
425
+ }
426
+
427
+ * size = sizeof (cu_ipc_data_t );
428
+ return UMF_RESULT_SUCCESS ;
429
+ }
430
+
431
+ static umf_result_t cu_memory_provider_get_ipc_handle (void * provider ,
432
+ const void * ptr ,
433
+ size_t size ,
434
+ void * providerIpcData ) {
435
+ (void )size ;
436
+
437
+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
438
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
439
+ }
440
+
441
+ CUresult cu_result ;
442
+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
443
+
444
+ cu_result = g_cu_ops .cuIpcGetMemHandle (cu_ipc_data , (CUdeviceptr )ptr );
445
+ if (cu_result != CUDA_SUCCESS ) {
446
+ LOG_ERR ("cuIpcGetMemHandle() failed." );
447
+ return cu2umf_result (cu_result );
448
+ }
449
+
450
+ return UMF_RESULT_SUCCESS ;
451
+ }
452
+
453
+ static umf_result_t cu_memory_provider_put_ipc_handle (void * provider ,
454
+ void * providerIpcData ) {
455
+ if (provider == NULL || providerIpcData == NULL ) {
456
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
457
+ }
458
+
459
+ return UMF_RESULT_SUCCESS ;
460
+ }
461
+
462
+ static umf_result_t cu_memory_provider_open_ipc_handle (void * provider ,
463
+ void * providerIpcData ,
464
+ void * * ptr ) {
465
+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
466
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
467
+ }
468
+
469
+ cu_memory_provider_t * cu_provider = (cu_memory_provider_t * )provider ;
470
+
471
+ CUresult cu_result ;
472
+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
473
+
474
+ // Remember current context and set the one from the provider
475
+ CUcontext restore_ctx = NULL ;
476
+ umf_result_t umf_result = set_context (cu_provider -> context , & restore_ctx );
477
+ if (umf_result != UMF_RESULT_SUCCESS ) {
478
+ return umf_result ;
479
+ }
480
+
481
+ cu_result = g_cu_ops .cuIpcOpenMemHandle ((CUdeviceptr * )ptr , * cu_ipc_data ,
482
+ CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS );
483
+
484
+ if (cu_result != CUDA_SUCCESS ) {
485
+ LOG_ERR ("cuIpcOpenMemHandle() failed." );
486
+ }
487
+
488
+ set_context (restore_ctx , & restore_ctx );
489
+
490
+ return cu2umf_result (cu_result );
491
+ }
492
+
493
+ static umf_result_t
494
+ cu_memory_provider_close_ipc_handle (void * provider , void * ptr , size_t size ) {
495
+ (void )size ;
496
+
497
+ if (provider == NULL || ptr == NULL ) {
498
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
499
+ }
500
+
501
+ CUresult cu_result ;
502
+
503
+ cu_result = g_cu_ops .cuIpcCloseMemHandle ((CUdeviceptr )ptr );
504
+ if (cu_result != CUDA_SUCCESS ) {
505
+ LOG_ERR ("cuIpcCloseMemHandle() failed." );
506
+ return cu2umf_result (cu_result );
507
+ }
508
+
509
+ return UMF_RESULT_SUCCESS ;
510
+ }
511
+
407
512
static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
408
513
.version = UMF_VERSION_CURRENT ,
409
514
.initialize = cu_memory_provider_initialize ,
@@ -420,12 +525,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
420
525
.ext.purge_force = cu_memory_provider_purge_force,
421
526
.ext.allocation_merge = cu_memory_provider_allocation_merge,
422
527
.ext.allocation_split = cu_memory_provider_allocation_split,
528
+ */
423
529
.ipc .get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size ,
424
530
.ipc .get_ipc_handle = cu_memory_provider_get_ipc_handle ,
425
531
.ipc .put_ipc_handle = cu_memory_provider_put_ipc_handle ,
426
532
.ipc .open_ipc_handle = cu_memory_provider_open_ipc_handle ,
427
533
.ipc .close_ipc_handle = cu_memory_provider_close_ipc_handle ,
428
- */
429
534
};
430
535
431
536
umf_memory_provider_ops_t * umfCUDAMemoryProviderOps (void ) {
0 commit comments