Skip to content

Commit b3bbdd4

Browse files
authored
Merge pull request #807 from vinser52/svinogra_ipc_cuda
IPC API in CUDA provider
2 parents 43e9af0 + 0da5844 commit b3bbdd4

10 files changed

+288
-4
lines changed

src/provider/provider_cuda.c

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,14 @@ typedef struct cu_ops_t {
5353
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
5454
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
5555
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
56+
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
57+
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
58+
unsigned int Flags);
59+
CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
5660
} cu_ops_t;
5761

62+
typedef CUipcMemHandle cu_ipc_data_t;
63+
5864
static cu_ops_t g_cu_ops;
5965
static UTIL_ONCE_FLAG cu_is_initialized = UTIL_ONCE_FLAG_INIT;
6066
static bool Init_cu_global_state_failed;
@@ -123,12 +129,20 @@ static void init_cu_global_state(void) {
123129
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
124130
*(void **)&g_cu_ops.cuCtxSetCurrent =
125131
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
132+
*(void **)&g_cu_ops.cuIpcGetMemHandle =
133+
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
134+
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
135+
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
136+
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
137+
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
126138

127139
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
128140
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
129141
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
130142
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
131-
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
143+
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
144+
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
145+
!g_cu_ops.cuIpcCloseMemHandle) {
132146
LOG_ERR("Required CUDA symbols not found.");
133147
Init_cu_global_state_failed = true;
134148
}
@@ -404,6 +418,97 @@ static const char *cu_memory_provider_get_name(void *provider) {
404418
return "CUDA";
405419
}
406420

421+
static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
422+
size_t *size) {
423+
if (provider == NULL || size == NULL) {
424+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
425+
}
426+
427+
*size = sizeof(cu_ipc_data_t);
428+
return UMF_RESULT_SUCCESS;
429+
}
430+
431+
static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
432+
const void *ptr,
433+
size_t size,
434+
void *providerIpcData) {
435+
(void)size;
436+
437+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
438+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
439+
}
440+
441+
CUresult cu_result;
442+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
443+
444+
cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
445+
if (cu_result != CUDA_SUCCESS) {
446+
LOG_ERR("cuIpcGetMemHandle() failed.");
447+
return cu2umf_result(cu_result);
448+
}
449+
450+
return UMF_RESULT_SUCCESS;
451+
}
452+
453+
static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
454+
void *providerIpcData) {
455+
if (provider == NULL || providerIpcData == NULL) {
456+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
457+
}
458+
459+
return UMF_RESULT_SUCCESS;
460+
}
461+
462+
static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
463+
void *providerIpcData,
464+
void **ptr) {
465+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
466+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
467+
}
468+
469+
cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
470+
471+
CUresult cu_result;
472+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
473+
474+
// Remember current context and set the one from the provider
475+
CUcontext restore_ctx = NULL;
476+
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
477+
if (umf_result != UMF_RESULT_SUCCESS) {
478+
return umf_result;
479+
}
480+
481+
cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
482+
CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
483+
484+
if (cu_result != CUDA_SUCCESS) {
485+
LOG_ERR("cuIpcOpenMemHandle() failed.");
486+
}
487+
488+
set_context(restore_ctx, &restore_ctx);
489+
490+
return cu2umf_result(cu_result);
491+
}
492+
493+
static umf_result_t
494+
cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
495+
(void)size;
496+
497+
if (provider == NULL || ptr == NULL) {
498+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
499+
}
500+
501+
CUresult cu_result;
502+
503+
cu_result = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
504+
if (cu_result != CUDA_SUCCESS) {
505+
LOG_ERR("cuIpcCloseMemHandle() failed.");
506+
return cu2umf_result(cu_result);
507+
}
508+
509+
return UMF_RESULT_SUCCESS;
510+
}
511+
407512
static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
408513
.version = UMF_VERSION_CURRENT,
409514
.initialize = cu_memory_provider_initialize,
@@ -420,12 +525,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
420525
.ext.purge_force = cu_memory_provider_purge_force,
421526
.ext.allocation_merge = cu_memory_provider_allocation_merge,
422527
.ext.allocation_split = cu_memory_provider_allocation_split,
528+
*/
423529
.ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
424530
.ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
425531
.ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
426532
.ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
427533
.ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
428-
*/
429534
};
430535

431536
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {

test/CMakeLists.txt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,40 @@ if(LINUX)
500500
PRIVATE ${LEVEL_ZERO_INCLUDE_DIRS})
501501
add_umf_ipc_test(TEST ipc_level_zero_prov SRC_DIR providers)
502502
endif()
503+
504+
if(UMF_BUILD_GPU_TESTS
505+
AND UMF_BUILD_CUDA_PROVIDER
506+
AND UMF_BUILD_LIBUMF_POOL_DISJOINT)
507+
build_umf_test(
508+
NAME
509+
ipc_cuda_prov_consumer
510+
SRCS
511+
providers/ipc_cuda_prov_consumer.c
512+
common/ipc_common.c
513+
providers/ipc_cuda_prov_common.c
514+
providers/cuda_helpers.cpp
515+
LIBS
516+
cuda
517+
disjoint_pool
518+
${UMF_UTILS_FOR_TEST})
519+
build_umf_test(
520+
NAME
521+
ipc_cuda_prov_producer
522+
SRCS
523+
providers/ipc_cuda_prov_producer.c
524+
common/ipc_common.c
525+
providers/ipc_cuda_prov_common.c
526+
providers/cuda_helpers.cpp
527+
LIBS
528+
cuda
529+
disjoint_pool
530+
${UMF_UTILS_FOR_TEST})
531+
target_include_directories(umf_test-ipc_cuda_prov_producer
532+
PRIVATE ${CUDA_INCLUDE_DIRS})
533+
target_include_directories(umf_test-ipc_cuda_prov_consumer
534+
PRIVATE ${CUDA_INCLUDE_DIRS})
535+
add_umf_ipc_test(TEST ipc_cuda_prov SRC_DIR providers)
536+
endif()
503537
else()
504538
message(STATUS "IPC tests are supported on Linux only - skipping")
505539
endif()

test/providers/cuda_helpers.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ struct libcu_ops {
3333
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
3434
CUpointer_attribute *attributes,
3535
void **data, CUdeviceptr ptr);
36+
CUresult (*cuStreamSynchronize)(CUstream hStream);
3637
} libcu_ops;
3738

3839
#if USE_DLOPEN
@@ -145,6 +146,13 @@ int InitCUDAOps() {
145146
lib_name);
146147
return -1;
147148
}
149+
*(void **)&libcu_ops.cuStreamSynchronize = utils_get_symbol_addr(
150+
cuDlHandle.get(), "cuStreamSynchronize", lib_name);
151+
if (libcu_ops.cuStreamSynchronize == nullptr) {
152+
fprintf(stderr, "cuStreamSynchronize symbol not found in %s\n",
153+
lib_name);
154+
return -1;
155+
}
148156

149157
return 0;
150158
}
@@ -167,6 +175,7 @@ int InitCUDAOps() {
167175
libcu_ops.cuMemcpy = cuMemcpy;
168176
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
169177
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
178+
libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
170179

171180
return 0;
172181
}
@@ -218,6 +227,12 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
218227
return -1;
219228
}
220229

230+
res = libcu_ops.cuStreamSynchronize(0);
231+
if (res != CUDA_SUCCESS) {
232+
fprintf(stderr, "cuStreamSynchronize() failed!\n");
233+
return -1;
234+
}
235+
221236
return ret;
222237
}
223238

test/providers/ipc_cuda_prov.sh

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#
2+
# Copyright (C) 2024 Intel Corporation
3+
#
4+
# Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
#
7+
8+
#!/bin/bash
9+
10+
set -e
11+
12+
# port should be a number from the range <1024, 65535>
13+
PORT=$(( 1024 + ( $$ % ( 65535 - 1024 ))))
14+
15+
UMF_LOG_VAL="level:debug;flush:debug;output:stderr;pid:yes"
16+
17+
echo "Starting ipc_cuda_prov CONSUMER on port $PORT ..."
18+
UMF_LOG=$UMF_LOG_VAL ./umf_test-ipc_cuda_prov_consumer $PORT &
19+
20+
echo "Waiting 1 sec ..."
21+
sleep 1
22+
23+
echo "Starting ipc_cuda_prov PRODUCER on port $PORT ..."
24+
UMF_LOG=$UMF_LOG_VAL ./umf_test-ipc_cuda_prov_producer $PORT

test/providers/ipc_cuda_prov_common.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright (C) 2024 Intel Corporation
3+
*
4+
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
5+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
*/
7+
8+
#include <stdio.h>
9+
10+
#include <umf/providers/provider_cuda.h>
11+
12+
#include "cuda_helpers.h"
13+
#include "ipc_cuda_prov_common.h"
14+
15+
void memcopy(void *dst, const void *src, size_t size, void *context) {
16+
cuda_memory_provider_params_t *cu_params =
17+
(cuda_memory_provider_params_t *)context;
18+
int ret = cuda_copy(cu_params->cuda_context_handle,
19+
cu_params->cuda_device_handle, dst, src, size);
20+
if (ret != 0) {
21+
fprintf(stderr, "cuda_copy failed with error %d\n", ret);
22+
}
23+
}

test/providers/ipc_cuda_prov_common.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/*
2+
* Copyright (C) 2024 Intel Corporation
3+
*
4+
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
5+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
*/
7+
8+
#ifndef UMF_TEST_IPC_CUDA_PROV_COMMON_H
9+
#define UMF_TEST_IPC_CUDA_PROV_COMMON_H
10+
11+
#include <stddef.h>
12+
13+
void memcopy(void *dst, const void *src, size_t size, void *context);
14+
15+
#endif // UMF_TEST_IPC_CUDA_PROV_COMMON_H
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (C) 2024 Intel Corporation
3+
*
4+
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
5+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
*/
7+
8+
#include <stdio.h>
9+
#include <stdlib.h>
10+
11+
#include <umf/pools/pool_disjoint.h>
12+
#include <umf/providers/provider_cuda.h>
13+
14+
#include "cuda_helpers.h"
15+
#include "ipc_common.h"
16+
#include "ipc_cuda_prov_common.h"
17+
18+
int main(int argc, char *argv[]) {
19+
if (argc < 2) {
20+
fprintf(stderr, "usage: %s <port>\n", argv[0]);
21+
return -1;
22+
}
23+
24+
int port = atoi(argv[1]);
25+
26+
cuda_memory_provider_params_t cu_params =
27+
create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);
28+
29+
umf_disjoint_pool_params_t pool_params = umfDisjointPoolParamsDefault();
30+
31+
return run_consumer(port, umfDisjointPoolOps(), &pool_params,
32+
umfCUDAMemoryProviderOps(), &cu_params, memcopy,
33+
&cu_params);
34+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (C) 2024 Intel Corporation
3+
*
4+
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
5+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
*/
7+
8+
#include <stdio.h>
9+
#include <stdlib.h>
10+
11+
#include <umf/pools/pool_disjoint.h>
12+
#include <umf/providers/provider_cuda.h>
13+
14+
#include "cuda_helpers.h"
15+
#include "ipc_common.h"
16+
#include "ipc_cuda_prov_common.h"
17+
18+
int main(int argc, char *argv[]) {
19+
if (argc < 2) {
20+
fprintf(stderr, "usage: %s <port>\n", argv[0]);
21+
return -1;
22+
}
23+
24+
int port = atoi(argv[1]);
25+
26+
cuda_memory_provider_params_t cu_params =
27+
create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);
28+
29+
umf_disjoint_pool_params_t pool_params = umfDisjointPoolParamsDefault();
30+
31+
return run_producer(port, umfDisjointPoolOps(), &pool_params,
32+
umfCUDAMemoryProviderOps(), &cu_params, memcopy,
33+
&cu_params);
34+
}

test/providers/ipc_level_zero_prov_consumer.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
int main(int argc, char *argv[]) {
1919
if (argc < 2) {
20-
fprintf(stderr, "usage: %s <port> [shm_name]\n", argv[0]);
20+
fprintf(stderr, "usage: %s <port>\n", argv[0]);
2121
return -1;
2222
}
2323

test/providers/ipc_level_zero_prov_producer.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
int main(int argc, char *argv[]) {
1919
if (argc < 2) {
20-
fprintf(stderr, "usage: %s <port> [shm_name]\n", argv[0]);
20+
fprintf(stderr, "usage: %s <port>\n", argv[0]);
2121
return -1;
2222
}
2323

0 commit comments

Comments
 (0)