Skip to content

IPC API in CUDA provider #807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 107 additions & 2 deletions src/provider/provider_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,14 @@ typedef struct cu_ops_t {
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
unsigned int Flags);
CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
} cu_ops_t;

// Provider-specific IPC payload: the CUDA driver's IPC memory handle is used
// as-is, so sizeof(cu_ipc_data_t) is the IPC handle size reported to callers.
typedef CUipcMemHandle cu_ipc_data_t;

// Table of CUDA driver entry points; populated by init_cu_global_state().
static cu_ops_t g_cu_ops;
// NOTE(review): presumably the once-flag guarding init_cu_global_state() - confirm at the call site.
static UTIL_ONCE_FLAG cu_is_initialized = UTIL_ONCE_FLAG_INIT;
// Set to true by init_cu_global_state() when a required CUDA symbol is missing.
static bool Init_cu_global_state_failed;
Expand Down Expand Up @@ -123,12 +129,20 @@ static void init_cu_global_state(void) {
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
*(void **)&g_cu_ops.cuCtxSetCurrent =
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
*(void **)&g_cu_ops.cuIpcGetMemHandle =
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);

if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
!g_cu_ops.cuIpcCloseMemHandle) {
LOG_ERR("Required CUDA symbols not found.");
Init_cu_global_state_failed = true;
}
Expand Down Expand Up @@ -404,6 +418,97 @@ static const char *cu_memory_provider_get_name(void *provider) {
return "CUDA";
}

// Reports how many bytes a CUDA provider IPC handle occupies.
//
// provider - provider instance; only checked for NULL here
// size     - output; receives sizeof(cu_ipc_data_t)
//
// Returns UMF_RESULT_ERROR_INVALID_ARGUMENT when either argument is NULL,
// UMF_RESULT_SUCCESS otherwise.
static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
                                                           size_t *size) {
    umf_result_t status = UMF_RESULT_ERROR_INVALID_ARGUMENT;

    if (provider != NULL && size != NULL) {
        *size = sizeof(cu_ipc_data_t);
        status = UMF_RESULT_SUCCESS;
    }

    return status;
}

static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
const void *ptr,
size_t size,
void *providerIpcData) {
(void)size;

if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

CUresult cu_result;
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;

cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
if (cu_result != CUDA_SUCCESS) {
LOG_ERR("cuIpcGetMemHandle() failed.");
return cu2umf_result(cu_result);
}

return UMF_RESULT_SUCCESS;
}

// Counterpart of get_ipc_handle. This implementation only validates its
// arguments and performs no driver call.
//
// Returns UMF_RESULT_ERROR_INVALID_ARGUMENT when either argument is NULL,
// UMF_RESULT_SUCCESS otherwise.
static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
                                                      void *providerIpcData) {
    const bool args_ok = (provider != NULL) && (providerIpcData != NULL);

    return args_ok ? UMF_RESULT_SUCCESS : UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
void *providerIpcData,
void **ptr) {
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;

CUresult cu_result;
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;

// Remember current context and set the one from the provider
CUcontext restore_ctx = NULL;
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
if (umf_result != UMF_RESULT_SUCCESS) {
return umf_result;
}

cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);

if (cu_result != CUDA_SUCCESS) {
LOG_ERR("cuIpcOpenMemHandle() failed.");
}

set_context(restore_ctx, &restore_ctx);

return cu2umf_result(cu_result);
}

// Passes ptr to cuIpcCloseMemHandle().
//
// provider - provider instance; only checked for NULL here
// ptr      - device pointer to close
// size     - unused
//
// Returns UMF_RESULT_ERROR_INVALID_ARGUMENT on a NULL argument, the converted
// driver error when cuIpcCloseMemHandle() fails, UMF_RESULT_SUCCESS otherwise.
static umf_result_t
cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
    (void)size;

    if (!provider || !ptr) {
        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
    }

    const CUresult status = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
    if (status == CUDA_SUCCESS) {
        return UMF_RESULT_SUCCESS;
    }

    LOG_ERR("cuIpcCloseMemHandle() failed.");
    return cu2umf_result(status);
}

static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
.version = UMF_VERSION_CURRENT,
.initialize = cu_memory_provider_initialize,
Expand All @@ -420,12 +525,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
.ext.purge_force = cu_memory_provider_purge_force,
.ext.allocation_merge = cu_memory_provider_allocation_merge,
.ext.allocation_split = cu_memory_provider_allocation_split,
*/
.ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
.ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
.ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
.ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
.ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
*/
};

umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
Expand Down
34 changes: 34 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,40 @@ if(LINUX)
PRIVATE ${LEVEL_ZERO_INCLUDE_DIRS})
add_umf_ipc_test(TEST ipc_level_zero_prov SRC_DIR providers)
endif()

# CUDA provider IPC tests: only built when GPU tests, the CUDA provider and
# the disjoint pool are all enabled.
if(UMF_BUILD_GPU_TESTS
AND UMF_BUILD_CUDA_PROVIDER
AND UMF_BUILD_LIBUMF_POOL_DISJOINT)
# Consumer side of the IPC test.
build_umf_test(
NAME
ipc_cuda_prov_consumer
SRCS
providers/ipc_cuda_prov_consumer.c
common/ipc_common.c
providers/ipc_cuda_prov_common.c
providers/cuda_helpers.cpp
LIBS
cuda
disjoint_pool
${UMF_UTILS_FOR_TEST})
# Producer side of the IPC test; same helper sources and libraries as the
# consumer.
build_umf_test(
NAME
ipc_cuda_prov_producer
SRCS
providers/ipc_cuda_prov_producer.c
common/ipc_common.c
providers/ipc_cuda_prov_common.c
providers/cuda_helpers.cpp
LIBS
cuda
disjoint_pool
${UMF_UTILS_FOR_TEST})
# Both test targets need the CUDA headers on their include path.
target_include_directories(umf_test-ipc_cuda_prov_producer
PRIVATE ${CUDA_INCLUDE_DIRS})
target_include_directories(umf_test-ipc_cuda_prov_consumer
PRIVATE ${CUDA_INCLUDE_DIRS})
# NOTE(review): presumably registers providers/ipc_cuda_prov.sh as the test
# driver - confirm against the add_umf_ipc_test() definition.
add_umf_ipc_test(TEST ipc_cuda_prov SRC_DIR providers)
endif()
else()
message(STATUS "IPC tests are supported on Linux only - skipping")
endif()
Expand Down
15 changes: 15 additions & 0 deletions test/providers/cuda_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct libcu_ops {
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
CUpointer_attribute *attributes,
void **data, CUdeviceptr ptr);
CUresult (*cuStreamSynchronize)(CUstream hStream);
} libcu_ops;

#if USE_DLOPEN
Expand Down Expand Up @@ -145,6 +146,13 @@ int InitCUDAOps() {
lib_name);
return -1;
}
*(void **)&libcu_ops.cuStreamSynchronize = utils_get_symbol_addr(
cuDlHandle.get(), "cuStreamSynchronize", lib_name);
if (libcu_ops.cuStreamSynchronize == nullptr) {
fprintf(stderr, "cuStreamSynchronize symbol not found in %s\n",
lib_name);
return -1;
}

return 0;
}
Expand All @@ -167,6 +175,7 @@ int InitCUDAOps() {
libcu_ops.cuMemcpy = cuMemcpy;
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
libcu_ops.cuStreamSynchronize = cuStreamSynchronize;

return 0;
}
Expand Down Expand Up @@ -218,6 +227,12 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
return -1;
}

res = libcu_ops.cuStreamSynchronize(0);
if (res != CUDA_SUCCESS) {
fprintf(stderr, "cuStreamSynchronize() failed!\n");
return -1;
}

return ret;
}

Expand Down
24 changes: 24 additions & 0 deletions test/providers/ipc_cuda_prov.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
#
# Copyright (C) 2024 Intel Corporation
#
# Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#

# Runs the CUDA provider IPC test: starts the consumer in the background,
# gives it a second to come up, then runs the producer against the same port.
# The shebang must be the very first line of the file to take effect, so it is
# placed above the license header.

# Abort on the first failing command.
set -e

# port should be a number from the range <1024, 65535>
# (derived from the shell's PID)
PORT=$(( 1024 + ( $$ % ( 65535 - 1024 ))))

UMF_LOG_VAL="level:debug;flush:debug;output:stderr;pid:yes"

echo "Starting ipc_cuda_prov CONSUMER on port $PORT ..."
UMF_LOG="$UMF_LOG_VAL" ./umf_test-ipc_cuda_prov_consumer "$PORT" &

echo "Waiting 1 sec ..."
sleep 1

echo "Starting ipc_cuda_prov PRODUCER on port $PORT ..."
UMF_LOG="$UMF_LOG_VAL" ./umf_test-ipc_cuda_prov_producer "$PORT"
23 changes: 23 additions & 0 deletions test/providers/ipc_cuda_prov_common.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (C) 2024 Intel Corporation
*
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/

#include <stdio.h>

#include <umf/providers/provider_cuda.h>

#include "cuda_helpers.h"
#include "ipc_cuda_prov_common.h"

// Copy callback for the IPC tests: copies size bytes from src to dst with
// cuda_copy(), using the CUDA context and device stored in the provider
// params passed as context.
//
// dst     - destination pointer
// src     - source pointer
// size    - number of bytes to copy
// context - must point to a cuda_memory_provider_params_t
//
// The callback returns void, so a failed copy is only reported on stderr.
void memcopy(void *dst, const void *src, size_t size, void *context) {
    cuda_memory_provider_params_t *cu_params =
        (cuda_memory_provider_params_t *)context;
    // cuda_copy() takes a non-const source pointer; drop the qualifier with
    // an explicit cast instead of relying on an implicit conversion.
    int ret = cuda_copy(cu_params->cuda_context_handle,
                        cu_params->cuda_device_handle, dst, (void *)src, size);
    if (ret != 0) {
        fprintf(stderr, "cuda_copy failed with error %d\n", ret);
    }
}
15 changes: 15 additions & 0 deletions test/providers/ipc_cuda_prov_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright (C) 2024 Intel Corporation
*
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/

#ifndef UMF_TEST_IPC_CUDA_PROV_COMMON_H
#define UMF_TEST_IPC_CUDA_PROV_COMMON_H

#include <stddef.h>

void memcopy(void *dst, const void *src, size_t size, void *context);

#endif // UMF_TEST_IPC_CUDA_PROV_COMMON_H
34 changes: 34 additions & 0 deletions test/providers/ipc_cuda_prov_consumer.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (C) 2024 Intel Corporation
*
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/

#include <stdio.h>
#include <stdlib.h>

#include <umf/pools/pool_disjoint.h>
#include <umf/providers/provider_cuda.h>

#include "cuda_helpers.h"
#include "ipc_common.h"
#include "ipc_cuda_prov_common.h"

// Entry point of the CUDA provider IPC consumer test.
//
// Usage: <program> <port>
//
// Builds CUDA provider params for device memory and default disjoint pool
// params, then hands control to run_consumer(). Returns -1 on a missing or
// invalid port argument, otherwise the value returned by run_consumer().
int main(int argc, char *argv[]) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <port>\n", argv[0]);
        return -1;
    }

    // Parse the port strictly: atoi() would silently turn garbage into 0.
    // An out-of-range value (including strtol() saturation) is rejected by
    // the bounds check.
    char *end = NULL;
    long port_arg = strtol(argv[1], &end, 10);
    if (end == argv[1] || *end != '\0' || port_arg <= 0 || port_arg > 65535) {
        fprintf(stderr, "invalid port: %s\n", argv[1]);
        return -1;
    }

    int port = (int)port_arg;

    cuda_memory_provider_params_t cu_params =
        create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);

    umf_disjoint_pool_params_t pool_params = umfDisjointPoolParamsDefault();

    return run_consumer(port, umfDisjointPoolOps(), &pool_params,
                        umfCUDAMemoryProviderOps(), &cu_params, memcopy,
                        &cu_params);
}
34 changes: 34 additions & 0 deletions test/providers/ipc_cuda_prov_producer.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (C) 2024 Intel Corporation
*
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/

#include <stdio.h>
#include <stdlib.h>

#include <umf/pools/pool_disjoint.h>
#include <umf/providers/provider_cuda.h>

#include "cuda_helpers.h"
#include "ipc_common.h"
#include "ipc_cuda_prov_common.h"

// Entry point of the CUDA provider IPC producer test.
//
// Usage: <program> <port>
//
// Builds CUDA provider params for device memory and default disjoint pool
// params, then hands control to run_producer(). Returns -1 on a missing or
// invalid port argument, otherwise the value returned by run_producer().
int main(int argc, char *argv[]) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <port>\n", argv[0]);
        return -1;
    }

    // Parse the port strictly: atoi() would silently turn garbage into 0.
    // An out-of-range value (including strtol() saturation) is rejected by
    // the bounds check.
    char *end = NULL;
    long port_arg = strtol(argv[1], &end, 10);
    if (end == argv[1] || *end != '\0' || port_arg <= 0 || port_arg > 65535) {
        fprintf(stderr, "invalid port: %s\n", argv[1]);
        return -1;
    }

    int port = (int)port_arg;

    cuda_memory_provider_params_t cu_params =
        create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);

    umf_disjoint_pool_params_t pool_params = umfDisjointPoolParamsDefault();

    return run_producer(port, umfDisjointPoolOps(), &pool_params,
                        umfCUDAMemoryProviderOps(), &cu_params, memcopy,
                        &cu_params);
}
2 changes: 1 addition & 1 deletion test/providers/ipc_level_zero_prov_consumer.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

int main(int argc, char *argv[]) {
if (argc < 2) {
fprintf(stderr, "usage: %s <port> [shm_name]\n", argv[0]);
fprintf(stderr, "usage: %s <port>\n", argv[0]);
return -1;
}

Expand Down
2 changes: 1 addition & 1 deletion test/providers/ipc_level_zero_prov_producer.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

int main(int argc, char *argv[]) {
if (argc < 2) {
fprintf(stderr, "usage: %s <port> [shm_name]\n", argv[0]);
fprintf(stderr, "usage: %s <port>\n", argv[0]);
return -1;
}

Expand Down
Loading