12
12
#include <umf.h>
13
13
#include <umf/providers/provider_cuda.h>
14
14
15
+ #include "provider_cuda_internal.h"
16
+ #include "utils_load_library.h"
15
17
#include "utils_log.h"
16
18
19
// Handle to the CUDA shared library held by this translation unit.
// It stays NULL until init_cu_global_state() stores a successfully
// opened handle in it.
static void *cu_lib_handle = NULL;

// Releases the CUDA shared library handle, if one is currently held,
// and clears it so that a repeated call is a no-op.
void fini_cu_global_state(void) {
    if (cu_lib_handle == NULL) {
        // nothing was opened (or it has already been closed)
        return;
    }

    utils_close_library(cu_lib_handle);
    cu_lib_handle = NULL;
}
27
+
17
28
#if defined(UMF_NO_CUDA_PROVIDER )
18
29
19
30
umf_result_t umfCUDAMemoryProviderParamsCreate (
@@ -88,7 +99,6 @@ umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
88
99
#include "utils_assert.h"
89
100
#include "utils_common.h"
90
101
#include "utils_concurrency.h"
91
- #include "utils_load_library.h"
92
102
#include "utils_log.h"
93
103
#include "utils_sanitizers.h"
94
104
@@ -180,37 +190,45 @@ static void init_cu_global_state(void) {
180
190
#else
181
191
const char * lib_name = "libcuda.so" ;
182
192
#endif
183
- // check if CUDA shared library is already loaded
184
- // we pass 0 as a handle to search the global symbol table
193
+ // The CUDA shared library should be already loaded by the user
194
+ // of the CUDA provider. UMF just wants to reuse it
195
+ // and increase the reference count to the CUDA shared library.
196
+ void * lib_handle =
197
+ utils_open_library (lib_name , UMF_UTIL_OPEN_LIBRARY_NO_LOAD );
198
+ if (!lib_handle ) {
199
+ LOG_ERR ("Failed to open CUDA shared library" );
200
+ Init_cu_global_state_failed = true;
201
+ return ;
202
+ }
185
203
186
204
// NOTE: some symbols defined in the lib have _vX postfixes - it is
187
205
// important to load the proper version of functions
188
- * (void * * )& g_cu_ops .cuMemGetAllocationGranularity =
189
- utils_get_symbol_addr ( 0 , "cuMemGetAllocationGranularity" , lib_name );
206
+ * (void * * )& g_cu_ops .cuMemGetAllocationGranularity = utils_get_symbol_addr (
207
+ lib_handle , "cuMemGetAllocationGranularity" , lib_name );
190
208
* (void * * )& g_cu_ops .cuMemAlloc =
191
- utils_get_symbol_addr (0 , "cuMemAlloc_v2" , lib_name );
209
+ utils_get_symbol_addr (lib_handle , "cuMemAlloc_v2" , lib_name );
192
210
* (void * * )& g_cu_ops .cuMemHostAlloc =
193
- utils_get_symbol_addr (0 , "cuMemHostAlloc" , lib_name );
211
+ utils_get_symbol_addr (lib_handle , "cuMemHostAlloc" , lib_name );
194
212
* (void * * )& g_cu_ops .cuMemAllocManaged =
195
- utils_get_symbol_addr (0 , "cuMemAllocManaged" , lib_name );
213
+ utils_get_symbol_addr (lib_handle , "cuMemAllocManaged" , lib_name );
196
214
* (void * * )& g_cu_ops .cuMemFree =
197
- utils_get_symbol_addr (0 , "cuMemFree_v2" , lib_name );
215
+ utils_get_symbol_addr (lib_handle , "cuMemFree_v2" , lib_name );
198
216
* (void * * )& g_cu_ops .cuMemFreeHost =
199
- utils_get_symbol_addr (0 , "cuMemFreeHost" , lib_name );
217
+ utils_get_symbol_addr (lib_handle , "cuMemFreeHost" , lib_name );
200
218
* (void * * )& g_cu_ops .cuGetErrorName =
201
- utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
219
+ utils_get_symbol_addr (lib_handle , "cuGetErrorName" , lib_name );
202
220
* (void * * )& g_cu_ops .cuGetErrorString =
203
- utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
221
+ utils_get_symbol_addr (lib_handle , "cuGetErrorString" , lib_name );
204
222
* (void * * )& g_cu_ops .cuCtxGetCurrent =
205
- utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
223
+ utils_get_symbol_addr (lib_handle , "cuCtxGetCurrent" , lib_name );
206
224
* (void * * )& g_cu_ops .cuCtxSetCurrent =
207
- utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
225
+ utils_get_symbol_addr (lib_handle , "cuCtxSetCurrent" , lib_name );
208
226
* (void * * )& g_cu_ops .cuIpcGetMemHandle =
209
- utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
227
+ utils_get_symbol_addr (lib_handle , "cuIpcGetMemHandle" , lib_name );
210
228
* (void * * )& g_cu_ops .cuIpcOpenMemHandle =
211
- utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
229
+ utils_get_symbol_addr (lib_handle , "cuIpcOpenMemHandle_v2" , lib_name );
212
230
* (void * * )& g_cu_ops .cuIpcCloseMemHandle =
213
- utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
231
+ utils_get_symbol_addr (lib_handle , "cuIpcCloseMemHandle" , lib_name );
214
232
215
233
if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
216
234
!g_cu_ops .cuMemHostAlloc || !g_cu_ops .cuMemAllocManaged ||
@@ -221,7 +239,10 @@ static void init_cu_global_state(void) {
221
239
!g_cu_ops .cuIpcCloseMemHandle ) {
222
240
LOG_FATAL ("Required CUDA symbols not found." );
223
241
Init_cu_global_state_failed = true;
242
+ utils_close_library (lib_handle );
243
+ return ;
224
244
}
245
+ cu_lib_handle = lib_handle ;
225
246
}
226
247
227
248
umf_result_t umfCUDAMemoryProviderParamsCreate (
@@ -327,7 +348,7 @@ static umf_result_t cu_memory_provider_initialize(void *params,
327
348
utils_init_once (& cu_is_initialized , init_cu_global_state );
328
349
if (Init_cu_global_state_failed ) {
329
350
LOG_FATAL ("Loading CUDA symbols failed" );
330
- return UMF_RESULT_ERROR_UNKNOWN ;
351
+ return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE ;
331
352
}
332
353
333
354
cu_memory_provider_t * cu_provider =
0 commit comments