Skip to content

Commit dfc2ffe

Browse files
committed
IPC API implementation for CUDA provider
1 parent b8088be commit dfc2ffe

File tree

1 file changed

+107
-2
lines changed

1 file changed

+107
-2
lines changed

src/provider/provider_cuda.c

Lines changed: 107 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,10 @@ typedef struct cu_ops_t {
5353
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
5454
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
5555
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
56+
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
57+
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
58+
unsigned int Flags);
59+
CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
5660
} cu_ops_t;
5761

5862
static cu_ops_t g_cu_ops;
@@ -123,12 +127,20 @@ static void init_cu_global_state(void) {
123127
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
124128
*(void **)&g_cu_ops.cuCtxSetCurrent =
125129
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
130+
*(void **)&g_cu_ops.cuIpcGetMemHandle =
131+
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
132+
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
133+
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
134+
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
135+
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
126136

127137
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
128138
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
129139
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
130140
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
131-
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
141+
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
142+
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
143+
!g_cu_ops.cuIpcCloseMemHandle) {
132144
LOG_ERR("Required CUDA symbols not found.");
133145
Init_cu_global_state_failed = true;
134146
}
@@ -404,6 +416,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
404416
return "CUDA";
405417
}
406418

419+
// Payload stored in the caller-supplied providerIpcData buffer by the IPC
// callbacks below: it is the CUDA driver's CUipcMemHandle, with no extra
// fields added by this provider.
typedef CUipcMemHandle cu_ipc_data_t;
420+
421+
static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
422+
size_t *size) {
423+
if (provider == NULL || size == NULL) {
424+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
425+
}
426+
427+
*size = sizeof(cu_ipc_data_t);
428+
return UMF_RESULT_SUCCESS;
429+
}
430+
431+
static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
432+
const void *ptr,
433+
size_t size,
434+
void *providerIpcData) {
435+
(void)size;
436+
437+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
438+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
439+
}
440+
441+
CUresult cu_result;
442+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
443+
444+
cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
445+
if (cu_result != CUDA_SUCCESS) {
446+
LOG_ERR("cuIpcGetMemHandle() failed.");
447+
return cu2umf_result(cu_result);
448+
}
449+
450+
return UMF_RESULT_SUCCESS;
451+
}
452+
453+
static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
454+
void *providerIpcData) {
455+
if (provider == NULL || providerIpcData == NULL) {
456+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
457+
}
458+
459+
return UMF_RESULT_SUCCESS;
460+
}
461+
462+
static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
463+
void *providerIpcData,
464+
void **ptr) {
465+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
466+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
467+
}
468+
469+
cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
470+
471+
CUresult cu_result;
472+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
473+
474+
// Remember current context and set the one from the provider
475+
CUcontext restore_ctx = NULL;
476+
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
477+
if (umf_result != UMF_RESULT_SUCCESS) {
478+
return umf_result;
479+
}
480+
481+
cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
482+
CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
483+
484+
if (cu_result != CUDA_SUCCESS) {
485+
LOG_ERR("cuIpcOpenMemHandle() failed.");
486+
}
487+
488+
set_context(restore_ctx, &restore_ctx);
489+
490+
return cu2umf_result(cu_result);
491+
}
492+
493+
static umf_result_t
494+
cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
495+
(void)size;
496+
497+
if (provider == NULL || ptr == NULL) {
498+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
499+
}
500+
501+
CUresult cu_result;
502+
503+
cu_result = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
504+
if (cu_result != CUDA_SUCCESS) {
505+
LOG_ERR("cuIpcCloseMemHandle() failed.");
506+
return cu2umf_result(cu_result);
507+
}
508+
509+
return UMF_RESULT_SUCCESS;
510+
}
511+
407512
static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
408513
.version = UMF_VERSION_CURRENT,
409514
.initialize = cu_memory_provider_initialize,
@@ -420,12 +525,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
420525
.ext.purge_force = cu_memory_provider_purge_force,
421526
.ext.allocation_merge = cu_memory_provider_allocation_merge,
422527
.ext.allocation_split = cu_memory_provider_allocation_split,
528+
*/
423529
.ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
424530
.ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
425531
.ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
426532
.ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
427533
.ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
428-
*/
429534
};
430535

431536
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {

0 commit comments

Comments
 (0)