Skip to content

Commit 55dc026

Browse files
committed
Increase refcount to CUDA library when CUDA provider is used
1 parent 716cdd2 commit 55dc026

File tree

3 files changed

+51
-18
lines changed

3 files changed

+51
-18
lines changed

src/libumf.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "base_alloc_global.h"
1313
#include "ipc_cache.h"
1414
#include "memspace_internal.h"
15+
#include "provider_cuda_internal.h"
1516
#include "provider_level_zero_internal.h"
1617
#include "provider_tracking.h"
1718
#include "utils_common.h"
@@ -81,6 +82,7 @@ void umfTearDown(void) {
8182

8283
fini_umfTearDown:
8384
fini_ze_global_state();
85+
fini_cu_global_state();
8486
LOG_DEBUG("UMF library finalized");
8587
}
8688
}

src/provider/provider_cuda.c

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,19 @@
1212
#include <umf.h>
1313
#include <umf/providers/provider_cuda.h>
1414

15+
#include "provider_cuda_internal.h"
16+
#include "utils_load_library.h"
1517
#include "utils_log.h"
1618

19+
static void *cu_lib_handle = NULL;
20+
21+
void fini_cu_global_state(void) {
22+
if (cu_lib_handle) {
23+
utils_close_library(cu_lib_handle);
24+
cu_lib_handle = NULL;
25+
}
26+
}
27+
1728
#if defined(UMF_NO_CUDA_PROVIDER)
1829

1930
umf_result_t umfCUDAMemoryProviderParamsCreate(
@@ -80,7 +91,6 @@ umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
8091
#include "utils_assert.h"
8192
#include "utils_common.h"
8293
#include "utils_concurrency.h"
83-
#include "utils_load_library.h"
8494
#include "utils_log.h"
8595
#include "utils_sanitizers.h"
8696

@@ -163,37 +173,45 @@ static void init_cu_global_state(void) {
163173
#else
164174
const char *lib_name = "libcuda.so";
165175
#endif
166-
// check if CUDA shared library is already loaded
167-
// we pass 0 as a handle to search the global symbol table
176+
// The CUDA shared library should be already loaded by the user
177+
// of the CUDA provider. UMF just wants to reuse it
178+
// and increase the reference count of the CUDA shared library.
179+
void *lib_handle =
180+
utils_open_library(lib_name, UMF_UTIL_OPEN_LIBRARY_NO_LOAD);
181+
if (!lib_handle) {
182+
LOG_ERR("Failed to open CUDA shared library");
183+
Init_cu_global_state_failed = true;
184+
return;
185+
}
168186

169187
// NOTE: some symbols defined in the lib have _vX postfixes - it is
170188
// important to load the proper version of functions
171-
*(void **)&g_cu_ops.cuMemGetAllocationGranularity =
172-
utils_get_symbol_addr(0, "cuMemGetAllocationGranularity", lib_name);
189+
*(void **)&g_cu_ops.cuMemGetAllocationGranularity = utils_get_symbol_addr(
190+
lib_handle, "cuMemGetAllocationGranularity", lib_name);
173191
*(void **)&g_cu_ops.cuMemAlloc =
174-
utils_get_symbol_addr(0, "cuMemAlloc_v2", lib_name);
192+
utils_get_symbol_addr(lib_handle, "cuMemAlloc_v2", lib_name);
175193
*(void **)&g_cu_ops.cuMemAllocHost =
176-
utils_get_symbol_addr(0, "cuMemAllocHost_v2", lib_name);
194+
utils_get_symbol_addr(lib_handle, "cuMemAllocHost_v2", lib_name);
177195
*(void **)&g_cu_ops.cuMemAllocManaged =
178-
utils_get_symbol_addr(0, "cuMemAllocManaged", lib_name);
196+
utils_get_symbol_addr(lib_handle, "cuMemAllocManaged", lib_name);
179197
*(void **)&g_cu_ops.cuMemFree =
180-
utils_get_symbol_addr(0, "cuMemFree_v2", lib_name);
198+
utils_get_symbol_addr(lib_handle, "cuMemFree_v2", lib_name);
181199
*(void **)&g_cu_ops.cuMemFreeHost =
182-
utils_get_symbol_addr(0, "cuMemFreeHost", lib_name);
200+
utils_get_symbol_addr(lib_handle, "cuMemFreeHost", lib_name);
183201
*(void **)&g_cu_ops.cuGetErrorName =
184-
utils_get_symbol_addr(0, "cuGetErrorName", lib_name);
202+
utils_get_symbol_addr(lib_handle, "cuGetErrorName", lib_name);
185203
*(void **)&g_cu_ops.cuGetErrorString =
186-
utils_get_symbol_addr(0, "cuGetErrorString", lib_name);
204+
utils_get_symbol_addr(lib_handle, "cuGetErrorString", lib_name);
187205
*(void **)&g_cu_ops.cuCtxGetCurrent =
188-
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
206+
utils_get_symbol_addr(lib_handle, "cuCtxGetCurrent", lib_name);
189207
*(void **)&g_cu_ops.cuCtxSetCurrent =
190-
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
208+
utils_get_symbol_addr(lib_handle, "cuCtxSetCurrent", lib_name);
191209
*(void **)&g_cu_ops.cuIpcGetMemHandle =
192-
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
210+
utils_get_symbol_addr(lib_handle, "cuIpcGetMemHandle", lib_name);
193211
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
194-
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
212+
utils_get_symbol_addr(lib_handle, "cuIpcOpenMemHandle_v2", lib_name);
195213
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
196-
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
214+
utils_get_symbol_addr(lib_handle, "cuIpcCloseMemHandle", lib_name);
197215

198216
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
199217
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
@@ -204,7 +222,10 @@ static void init_cu_global_state(void) {
204222
!g_cu_ops.cuIpcCloseMemHandle) {
205223
LOG_ERR("Required CUDA symbols not found.");
206224
Init_cu_global_state_failed = true;
225+
utils_close_library(lib_handle);
226+
return;
207227
}
228+
cu_lib_handle = lib_handle;
208229
}
209230

210231
umf_result_t umfCUDAMemoryProviderParamsCreate(
@@ -297,7 +318,7 @@ static umf_result_t cu_memory_provider_initialize(void *params,
297318
utils_init_once(&cu_is_initialized, init_cu_global_state);
298319
if (Init_cu_global_state_failed) {
299320
LOG_ERR("Loading CUDA symbols failed");
300-
return UMF_RESULT_ERROR_UNKNOWN;
321+
return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE;
301322
}
302323

303324
cu_memory_provider_t *cu_provider =

src/provider/provider_cuda_internal.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
/*
2+
*
3+
* Copyright (C) 2025 Intel Corporation
4+
*
5+
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
6+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
*
8+
*/
9+
10+
void fini_cu_global_state(void);

0 commit comments

Comments
 (0)