Skip to content

Commit feb0001

Browse files
committed
IPC API implementation for CUDA provider
1 parent 11c6408 commit feb0001

File tree

1 file changed

+107
-2
lines changed

1 file changed

+107
-2
lines changed

src/provider/provider_cuda.c

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ typedef struct cu_ops_t {
5353
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
5454
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
5555
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
56+
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
57+
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
58+
unsigned int Flags);
59+
CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
5660
} cu_ops_t;
5761

5862
static cu_ops_t g_cu_ops;
@@ -123,12 +127,20 @@ static void init_cu_global_state(void) {
123127
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
124128
*(void **)&g_cu_ops.cuCtxSetCurrent =
125129
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
130+
*(void **)&g_cu_ops.cuIpcGetMemHandle =
131+
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
132+
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
133+
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
134+
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
135+
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
126136

127137
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
128138
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
129139
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
130140
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
131-
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
141+
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
142+
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
143+
!g_cu_ops.cuIpcCloseMemHandle) {
132144
LOG_ERR("Required CUDA symbols not found.");
133145
Init_cu_global_state_failed = true;
134146
}
@@ -396,6 +408,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
396408
return "CUDA";
397409
}
398410

411+
// Provider-specific IPC payload type: an alias for the CUDA IPC memory handle.
typedef CUipcMemHandle cu_ipc_data_t;
412+
413+
static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
414+
size_t *size) {
415+
if (provider == NULL || size == NULL) {
416+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
417+
}
418+
419+
*size = sizeof(cu_ipc_data_t);
420+
return UMF_RESULT_SUCCESS;
421+
}
422+
423+
static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
424+
const void *ptr,
425+
size_t size,
426+
void *providerIpcData) {
427+
(void)size;
428+
429+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
430+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
431+
}
432+
433+
CUresult cu_result;
434+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
435+
436+
cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
437+
if (cu_result != CUDA_SUCCESS) {
438+
LOG_ERR("cuIpcGetMemHandle() failed.");
439+
return cu2umf_result(cu_result);
440+
}
441+
442+
return UMF_RESULT_SUCCESS;
443+
}
444+
445+
static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
446+
void *providerIpcData) {
447+
if (provider == NULL || providerIpcData == NULL) {
448+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
449+
}
450+
451+
return UMF_RESULT_SUCCESS;
452+
}
453+
454+
static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
455+
void *providerIpcData,
456+
void **ptr) {
457+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
458+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
459+
}
460+
461+
cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
462+
463+
CUresult cu_result;
464+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
465+
466+
// Remember current context and set the one from the provider
467+
CUcontext restore_ctx = NULL;
468+
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
469+
if (umf_result != UMF_RESULT_SUCCESS) {
470+
return umf_result;
471+
}
472+
473+
cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
474+
CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
475+
476+
if (cu_result != CUDA_SUCCESS) {
477+
LOG_ERR("cuIpcOpenMemHandle() failed.");
478+
}
479+
480+
set_context(restore_ctx, &restore_ctx);
481+
482+
return cu2umf_result(cu_result);
483+
}
484+
485+
static umf_result_t
486+
cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
487+
(void)size;
488+
489+
if (provider == NULL || ptr == NULL) {
490+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
491+
}
492+
493+
CUresult cu_result;
494+
495+
cu_result = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
496+
if (cu_result != CUDA_SUCCESS) {
497+
LOG_ERR("cuIpcCloseMemHandle() failed.");
498+
return cu2umf_result(cu_result);
499+
}
500+
501+
return UMF_RESULT_SUCCESS;
502+
}
503+
399504
static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
400505
.version = UMF_VERSION_CURRENT,
401506
.initialize = cu_memory_provider_initialize,
@@ -412,12 +517,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
412517
.ext.purge_force = cu_memory_provider_purge_force,
413518
.ext.allocation_merge = cu_memory_provider_allocation_merge,
414519
.ext.allocation_split = cu_memory_provider_allocation_split,
520+
*/
415521
.ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
416522
.ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
417523
.ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
418524
.ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
419525
.ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
420-
*/
421526
};
422527

423528
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {

0 commit comments

Comments
 (0)