Skip to content

Fix CUDA provider #810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 55 additions & 3 deletions src/provider/provider_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ typedef struct cu_ops_t {

CUresult (*cuGetErrorName)(CUresult error, const char **pStr);
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
} cu_ops_t;

static cu_ops_t g_cu_ops;
Expand Down Expand Up @@ -117,11 +119,16 @@ static void init_cu_global_state(void) {
utils_get_symbol_addr(0, "cuGetErrorName", lib_name);
*(void **)&g_cu_ops.cuGetErrorString =
utils_get_symbol_addr(0, "cuGetErrorString", lib_name);
*(void **)&g_cu_ops.cuCtxGetCurrent =
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
*(void **)&g_cu_ops.cuCtxSetCurrent =
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);

if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString) {
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
LOG_ERR("Required CUDA symbols not found.");
Init_cu_global_state_failed = true;
}
Expand Down Expand Up @@ -190,6 +197,31 @@ static void cu_memory_provider_finalize(void *provider) {
umf_ba_global_free(provider);
}

/*
 * Helper used by the CUDA provider to switch to the context it needs.
 *
 * Queries the current CUDA context through g_cu_ops.cuCtxGetCurrent() and,
 * when that query succeeds, always stores the result in *restore_ctx so the
 * caller can switch back later. cuCtxSetCurrent(required_ctx) is only called
 * when the current context differs from required_ctx.
 *
 * Returns UMF_RESULT_SUCCESS on success, otherwise the failing CUresult
 * translated by cu2umf_result(). If the initial query fails, *restore_ctx
 * is left untouched.
 */
static inline umf_result_t set_context(CUcontext required_ctx,
                                       CUcontext *restore_ctx) {
    CUcontext active_ctx = NULL;
    CUresult status = g_cu_ops.cuCtxGetCurrent(&active_ctx);
    if (status != CUDA_SUCCESS) {
        LOG_ERR("cuCtxGetCurrent() failed.");
        return cu2umf_result(status);
    }

    // Hand the previous context back to the caller before any switch.
    *restore_ctx = active_ctx;

    // Nothing to do when the required context is already current.
    if (active_ctx == required_ctx) {
        return UMF_RESULT_SUCCESS;
    }

    status = g_cu_ops.cuCtxSetCurrent(required_ctx);
    if (status != CUDA_SUCCESS) {
        LOG_ERR("cuCtxSetCurrent() failed.");
        return cu2umf_result(status);
    }

    return UMF_RESULT_SUCCESS;
}

static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
size_t alignment,
void **resultPtr) {
Expand All @@ -205,6 +237,14 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
return UMF_RESULT_ERROR_NOT_SUPPORTED;
}

// Remember current context and set the one from the provider
CUcontext restore_ctx = NULL;
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
if (umf_result != UMF_RESULT_SUCCESS) {
LOG_ERR("Failed to set CUDA context, ret = %d", umf_result);
return umf_result;
}

CUresult cu_result = CUDA_SUCCESS;
switch (cu_provider->memory_type) {
case UMF_MEMORY_TYPE_HOST: {
Expand All @@ -224,17 +264,29 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
// this shouldn't happen as we check the memory_type settings during
// the initialization
LOG_ERR("unsupported USM memory type");
assert(false);
return UMF_RESULT_ERROR_UNKNOWN;
}

umf_result = set_context(restore_ctx, &restore_ctx);
if (umf_result != UMF_RESULT_SUCCESS) {
LOG_ERR("Failed to restore CUDA context, ret = %d", umf_result);
}

umf_result = cu2umf_result(cu_result);
if (umf_result != UMF_RESULT_SUCCESS) {
LOG_ERR("Failed to allocate memory, cu_result = %d, ret = %d",
cu_result, umf_result);
return umf_result;
}

// check the alignment
if (alignment > 0 && ((uintptr_t)(*resultPtr) % alignment) != 0) {
cu_memory_provider_free(provider, *resultPtr, size);
LOG_ERR("unsupported alignment size");
return UMF_RESULT_ERROR_INVALID_ALIGNMENT;
}

return cu2umf_result(cu_result);
return umf_result;
}

static umf_result_t cu_memory_provider_free(void *provider, void *ptr,
Expand Down
58 changes: 50 additions & 8 deletions test/providers/cuda_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct libcu_ops {
CUresult (*cuInit)(unsigned int flags);
CUresult (*cuCtxCreate)(CUcontext *pctx, unsigned int flags, CUdevice dev);
CUresult (*cuCtxDestroy)(CUcontext ctx);
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
CUresult (*cuDeviceGet)(CUdevice *device, int ordinal);
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
CUresult (*cuMemFree)(CUdeviceptr dptr);
Expand All @@ -26,7 +27,9 @@ struct libcu_ops {
CUresult (*cuMemFreeHost)(void *p);
CUresult (*cuMemsetD32)(CUdeviceptr dstDevice, unsigned int pattern,
size_t size);
CUresult (*cuMemcpyDtoH)(void *dstHost, CUdeviceptr srcDevice, size_t size);
CUresult (*cuMemcpy)(CUdeviceptr dst, CUdeviceptr src, size_t size);
CUresult (*cuPointerGetAttribute)(void *data, CUpointer_attribute attribute,
CUdeviceptr ptr);
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
CUpointer_attribute *attributes,
void **data, CUdeviceptr ptr);
Expand Down Expand Up @@ -74,6 +77,12 @@ int InitCUDAOps() {
fprintf(stderr, "cuCtxDestroy symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuCtxGetCurrent =
utils_get_symbol_addr(cuDlHandle.get(), "cuCtxGetCurrent", lib_name);
if (libcu_ops.cuCtxGetCurrent == nullptr) {
fprintf(stderr, "cuCtxGetCurrent symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuDeviceGet =
utils_get_symbol_addr(cuDlHandle.get(), "cuDeviceGet", lib_name);
if (libcu_ops.cuDeviceGet == nullptr) {
Expand Down Expand Up @@ -116,10 +125,17 @@ int InitCUDAOps() {
fprintf(stderr, "cuMemsetD32_v2 symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuMemcpyDtoH =
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpyDtoH_v2", lib_name);
if (libcu_ops.cuMemcpyDtoH == nullptr) {
fprintf(stderr, "cuMemcpyDtoH_v2 symbol not found in %s\n", lib_name);
*(void **)&libcu_ops.cuMemcpy =
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpy", lib_name);
if (libcu_ops.cuMemcpy == nullptr) {
fprintf(stderr, "cuMemcpy symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuPointerGetAttribute = utils_get_symbol_addr(
cuDlHandle.get(), "cuPointerGetAttribute", lib_name);
if (libcu_ops.cuPointerGetAttribute == nullptr) {
fprintf(stderr, "cuPointerGetAttribute symbol not found in %s\n",
lib_name);
return -1;
}
*(void **)&libcu_ops.cuPointerGetAttributes = utils_get_symbol_addr(
Expand All @@ -140,14 +156,16 @@ int InitCUDAOps() {
libcu_ops.cuInit = cuInit;
libcu_ops.cuCtxCreate = cuCtxCreate;
libcu_ops.cuCtxDestroy = cuCtxDestroy;
libcu_ops.cuCtxGetCurrent = cuCtxGetCurrent;
libcu_ops.cuDeviceGet = cuDeviceGet;
libcu_ops.cuMemAlloc = cuMemAlloc;
libcu_ops.cuMemAllocHost = cuMemAllocHost;
libcu_ops.cuMemAllocManaged = cuMemAllocManaged;
libcu_ops.cuMemFree = cuMemFree;
libcu_ops.cuMemFreeHost = cuMemFreeHost;
libcu_ops.cuMemsetD32 = cuMemsetD32;
libcu_ops.cuMemcpyDtoH = cuMemcpyDtoH;
libcu_ops.cuMemcpy = cuMemcpy;
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;

return 0;
Expand Down Expand Up @@ -193,9 +211,10 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
(void)device;

int ret = 0;
CUresult res = libcu_ops.cuMemcpyDtoH(dst_ptr, (CUdeviceptr)src_ptr, size);
CUresult res =
libcu_ops.cuMemcpy((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
if (res != CUDA_SUCCESS) {
fprintf(stderr, "cuMemcpyDtoH() failed!\n");
fprintf(stderr, "cuMemcpy() failed!\n");
return -1;
}

Expand Down Expand Up @@ -230,6 +249,29 @@ umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr) {
return UMF_MEMORY_TYPE_UNKNOWN;
}

// Test helper: asks the CUDA driver (via libcu_ops.cuPointerGetAttribute with
// CU_POINTER_ATTRIBUTE_CONTEXT) for the context attribute of `ptr`.
// Returns that context on success; on failure prints a message to stderr and
// returns nullptr.
CUcontext get_mem_context(void *ptr) {
    CUcontext owner_ctx;
    const CUresult status = libcu_ops.cuPointerGetAttribute(
        &owner_ctx, CU_POINTER_ATTRIBUTE_CONTEXT, (CUdeviceptr)ptr);
    if (status == CUDA_SUCCESS) {
        return owner_ctx;
    }

    fprintf(stderr, "cuPointerGetAttribute() failed!\n");
    return nullptr;
}

// Test helper: returns the context reported by libcu_ops.cuCtxGetCurrent().
// On failure prints a message to stderr and returns nullptr.
CUcontext get_current_context() {
    CUcontext current_ctx;
    const CUresult status = libcu_ops.cuCtxGetCurrent(&current_ctx);
    if (status == CUDA_SUCCESS) {
        return current_ctx;
    }

    fprintf(stderr, "cuCtxGetCurrent() failed!\n");
    return nullptr;
}

UTIL_ONCE_FLAG cuda_init_flag;
int InitResult;
void init_cuda_once() {
Expand Down
4 changes: 4 additions & 0 deletions test/providers/cuda_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,

umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr);

CUcontext get_mem_context(void *ptr);

CUcontext get_current_context();

cuda_memory_provider_params_t
create_cuda_prov_params(umf_usm_memory_type_t memory_type);

Expand Down
69 changes: 32 additions & 37 deletions test/providers/provider_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ using namespace umf_test;

class CUDAMemoryAccessor : public MemoryAccessor {
public:
void init(CUcontext hContext, CUdevice hDevice) {
hDevice_ = hDevice;
hContext_ = hContext;
}
CUDAMemoryAccessor(CUcontext hContext, CUdevice hDevice)
: hDevice_(hDevice), hContext_(hContext) {}

void fill(void *ptr, size_t size, const void *pattern,
size_t pattern_size) {
Expand Down Expand Up @@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
};

using CUDAProviderTestParams =
std::tuple<umf_usm_memory_type_t, MemoryAccessor *>;
std::tuple<cuda_memory_provider_params_t, MemoryAccessor *>;

struct umfCUDAProviderTest
: umf_test::test,
Expand All @@ -62,23 +60,12 @@ struct umfCUDAProviderTest
void SetUp() override {
test::SetUp();

auto [memory_type, accessor] = this->GetParam();
params = create_cuda_prov_params(memory_type);
auto [cuda_params, accessor] = this->GetParam();
params = cuda_params;
memAccessor = accessor;
if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
((CUDAMemoryAccessor *)memAccessor)
->init((CUcontext)params.cuda_context_handle,
params.cuda_device_handle);
}
}

void TearDown() override {
if (params.cuda_context_handle) {
int ret = destroy_context((CUcontext)params.cuda_context_handle);
ASSERT_EQ(ret, 0);
}
test::TearDown();
}
void TearDown() override { test::TearDown(); }

cuda_memory_provider_params_t params;
MemoryAccessor *memAccessor = nullptr;
Expand All @@ -87,6 +74,7 @@ struct umfCUDAProviderTest
TEST_P(umfCUDAProviderTest, basic) {
const size_t size = 1024 * 8;
const uint32_t pattern = 0xAB;
CUcontext expected_current_context = get_current_context();

// create CUDA provider
umf_memory_provider_handle_t provider = nullptr;
Expand All @@ -113,6 +101,12 @@ TEST_P(umfCUDAProviderTest, basic) {
// use the allocated memory - fill it with a 0xAB pattern
memAccessor->fill(ptr, size, &pattern, sizeof(pattern));

CUcontext actual_mem_context = get_mem_context(ptr);
ASSERT_EQ(actual_mem_context, (CUcontext)params.cuda_context_handle);

CUcontext actual_current_context = get_current_context();
ASSERT_EQ(actual_current_context, expected_current_context);

umf_usm_memory_type_t memoryTypeActual =
get_mem_type((CUcontext)params.cuda_context_handle, ptr);
ASSERT_EQ(memoryTypeActual, params.memory_type);
Expand All @@ -132,6 +126,7 @@ TEST_P(umfCUDAProviderTest, basic) {
}

TEST_P(umfCUDAProviderTest, allocInvalidSize) {
CUcontext expected_current_context = get_current_context();
// create CUDA provider
umf_memory_provider_handle_t provider = nullptr;
umf_result_t umf_result =
Expand All @@ -151,32 +146,32 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
ASSERT_EQ(umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
}

// destroy context and try to alloc some memory
destroy_context((CUcontext)params.cuda_context_handle);
params.cuda_context_handle = 0;
umf_result = umfMemoryProviderAlloc(provider, 128, 0, &ptr);
ASSERT_EQ(umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC);

const char *message;
int32_t error;
umfMemoryProviderGetLastNativeError(provider, &message, &error);
ASSERT_EQ(error, CUDA_ERROR_INVALID_CONTEXT);
const char *expected_message =
"CUDA_ERROR_INVALID_CONTEXT - invalid device context";
ASSERT_EQ(strncmp(message, expected_message, strlen(expected_message)), 0);
CUcontext actual_current_context = get_current_context();
ASSERT_EQ(actual_current_context, expected_current_context);

umfMemoryProviderDestroy(provider);
}

// TODO: add tests that mix the CUDA Memory Provider and the Disjoint Pool

CUDAMemoryAccessor cuAccessor;
cuda_memory_provider_params_t cuParams_device_memory =
create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);
cuda_memory_provider_params_t cuParams_shared_memory =
create_cuda_prov_params(UMF_MEMORY_TYPE_SHARED);
cuda_memory_provider_params_t cuParams_host_memory =
create_cuda_prov_params(UMF_MEMORY_TYPE_HOST);

CUDAMemoryAccessor
cuAccessor((CUcontext)cuParams_device_memory.cuda_context_handle,
(CUdevice)cuParams_device_memory.cuda_device_handle);
HostMemoryAccessor hostAccessor;

INSTANTIATE_TEST_SUITE_P(
umfCUDAProviderTestSuite, umfCUDAProviderTest,
::testing::Values(
CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE, &cuAccessor},
CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED, &hostAccessor},
CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST, &hostAccessor}));
CUDAProviderTestParams{cuParams_device_memory, &cuAccessor},
CUDAProviderTestParams{cuParams_shared_memory, &hostAccessor},
CUDAProviderTestParams{cuParams_host_memory, &hostAccessor}));

// TODO: add IPC API
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(umfIpcTest);
Expand All @@ -185,5 +180,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
::testing::Values(ipcTestParams{
umfProxyPoolOps(), nullptr,
umfCUDAMemoryProviderOps(),
&cuParams_device_memory, &l0Accessor}));
&cuParams_device_memory, &cuAccessor}));
*/
Loading