Skip to content

Commit c715e62

Browse files
committed
Increase refcount to CUDA library when CUDA provider is used
1 parent 1db4c48 commit c715e62

File tree

3 files changed

+51
-18
lines changed

3 files changed

+51
-18
lines changed

src/libumf.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "base_alloc_global.h"
1313
#include "ipc_cache.h"
1414
#include "memspace_internal.h"
15+
#include "provider_cuda_internal.h"
1516
#include "provider_level_zero_internal.h"
1617
#include "provider_tracking.h"
1718
#include "utils_common.h"
@@ -81,6 +82,7 @@ void umfTearDown(void) {
8182

8283
fini_umfTearDown:
8384
fini_ze_global_state();
85+
fini_cu_global_state();
8486
LOG_DEBUG("UMF library finalized");
8587
}
8688
}

src/provider/provider_cuda.c

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,19 @@
1212
#include <umf.h>
1313
#include <umf/providers/provider_cuda.h>
1414

15+
#include "provider_cuda_internal.h"
16+
#include "utils_load_library.h"
1517
#include "utils_log.h"
1618

19+
// Handle to the CUDA shared library. It is assigned by
// init_cu_global_state() once all required symbols have been resolved,
// and stays NULL otherwise.
static void *cu_lib_handle = NULL;
20+
21+
// Release the reference to the CUDA shared library, if one is held.
// Closes cu_lib_handle and resets it to NULL, so calling this function
// again afterwards does nothing.
void fini_cu_global_state(void) {
22+
if (cu_lib_handle) {
23+
utils_close_library(cu_lib_handle);
24+
// clear the handle so the library is not closed a second time
cu_lib_handle = NULL;
25+
}
26+
}
27+
1728
#if defined(UMF_NO_CUDA_PROVIDER)
1829

1930
umf_result_t umfCUDAMemoryProviderParamsCreate(
@@ -88,7 +99,6 @@ umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
8899
#include "utils_assert.h"
89100
#include "utils_common.h"
90101
#include "utils_concurrency.h"
91-
#include "utils_load_library.h"
92102
#include "utils_log.h"
93103
#include "utils_sanitizers.h"
94104

@@ -180,37 +190,45 @@ static void init_cu_global_state(void) {
180190
#else
181191
const char *lib_name = "libcuda.so";
182192
#endif
183-
// check if CUDA shared library is already loaded
184-
// we pass 0 as a handle to search the global symbol table
193+
// The CUDA shared library should be already loaded by the user
194+
// of the CUDA provider. UMF just wants to reuse it
195+
// and increase the reference count of the CUDA shared library.
196+
void *lib_handle =
197+
utils_open_library(lib_name, UMF_UTIL_OPEN_LIBRARY_NO_LOAD);
198+
if (!lib_handle) {
199+
LOG_ERR("Failed to open CUDA shared library");
200+
Init_cu_global_state_failed = true;
201+
return;
202+
}
185203

186204
// NOTE: some symbols defined in the lib have _vX postfixes - it is
187205
// important to load the proper version of functions
188-
*(void **)&g_cu_ops.cuMemGetAllocationGranularity =
189-
utils_get_symbol_addr(0, "cuMemGetAllocationGranularity", lib_name);
206+
*(void **)&g_cu_ops.cuMemGetAllocationGranularity = utils_get_symbol_addr(
207+
lib_handle, "cuMemGetAllocationGranularity", lib_name);
190208
*(void **)&g_cu_ops.cuMemAlloc =
191-
utils_get_symbol_addr(0, "cuMemAlloc_v2", lib_name);
209+
utils_get_symbol_addr(lib_handle, "cuMemAlloc_v2", lib_name);
192210
*(void **)&g_cu_ops.cuMemHostAlloc =
193-
utils_get_symbol_addr(0, "cuMemHostAlloc", lib_name);
211+
utils_get_symbol_addr(lib_handle, "cuMemHostAlloc", lib_name);
194212
*(void **)&g_cu_ops.cuMemAllocManaged =
195-
utils_get_symbol_addr(0, "cuMemAllocManaged", lib_name);
213+
utils_get_symbol_addr(lib_handle, "cuMemAllocManaged", lib_name);
196214
*(void **)&g_cu_ops.cuMemFree =
197-
utils_get_symbol_addr(0, "cuMemFree_v2", lib_name);
215+
utils_get_symbol_addr(lib_handle, "cuMemFree_v2", lib_name);
198216
*(void **)&g_cu_ops.cuMemFreeHost =
199-
utils_get_symbol_addr(0, "cuMemFreeHost", lib_name);
217+
utils_get_symbol_addr(lib_handle, "cuMemFreeHost", lib_name);
200218
*(void **)&g_cu_ops.cuGetErrorName =
201-
utils_get_symbol_addr(0, "cuGetErrorName", lib_name);
219+
utils_get_symbol_addr(lib_handle, "cuGetErrorName", lib_name);
202220
*(void **)&g_cu_ops.cuGetErrorString =
203-
utils_get_symbol_addr(0, "cuGetErrorString", lib_name);
221+
utils_get_symbol_addr(lib_handle, "cuGetErrorString", lib_name);
204222
*(void **)&g_cu_ops.cuCtxGetCurrent =
205-
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
223+
utils_get_symbol_addr(lib_handle, "cuCtxGetCurrent", lib_name);
206224
*(void **)&g_cu_ops.cuCtxSetCurrent =
207-
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
225+
utils_get_symbol_addr(lib_handle, "cuCtxSetCurrent", lib_name);
208226
*(void **)&g_cu_ops.cuIpcGetMemHandle =
209-
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
227+
utils_get_symbol_addr(lib_handle, "cuIpcGetMemHandle", lib_name);
210228
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
211-
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
229+
utils_get_symbol_addr(lib_handle, "cuIpcOpenMemHandle_v2", lib_name);
212230
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
213-
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
231+
utils_get_symbol_addr(lib_handle, "cuIpcCloseMemHandle", lib_name);
214232

215233
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
216234
!g_cu_ops.cuMemHostAlloc || !g_cu_ops.cuMemAllocManaged ||
@@ -221,7 +239,10 @@ static void init_cu_global_state(void) {
221239
!g_cu_ops.cuIpcCloseMemHandle) {
222240
LOG_FATAL("Required CUDA symbols not found.");
223241
Init_cu_global_state_failed = true;
242+
utils_close_library(lib_handle);
243+
return;
224244
}
245+
cu_lib_handle = lib_handle;
225246
}
226247

227248
umf_result_t umfCUDAMemoryProviderParamsCreate(
@@ -327,7 +348,7 @@ static umf_result_t cu_memory_provider_initialize(void *params,
327348
utils_init_once(&cu_is_initialized, init_cu_global_state);
328349
if (Init_cu_global_state_failed) {
329350
LOG_FATAL("Loading CUDA symbols failed");
330-
return UMF_RESULT_ERROR_UNKNOWN;
351+
return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE;
331352
}
332353

333354
cu_memory_provider_t *cu_provider =

src/provider/provider_cuda_internal.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
/*
2+
*
3+
* Copyright (C) 2025 Intel Corporation
4+
*
5+
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
6+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
*
8+
*/
9+
10+
void fini_cu_global_state(void);

0 commit comments

Comments
 (0)