12
12
#include <umf.h>
13
13
#include <umf/providers/provider_cuda.h>
14
14
15
+ #include "provider_cuda_internal.h"
16
+ #include "utils_load_library.h"
15
17
#include "utils_log.h"
16
18
19
+ static void * cu_lib_handle = NULL ;
20
+
21
+ void fini_cu_global_state (void ) {
22
+ if (cu_lib_handle ) {
23
+ utils_close_library (cu_lib_handle );
24
+ cu_lib_handle = NULL ;
25
+ }
26
+ }
27
+
17
28
#if defined(UMF_NO_CUDA_PROVIDER )
18
29
19
30
umf_result_t umfCUDAMemoryProviderParamsCreate (
@@ -80,7 +91,6 @@ umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
80
91
#include "utils_assert.h"
81
92
#include "utils_common.h"
82
93
#include "utils_concurrency.h"
83
- #include "utils_load_library.h"
84
94
#include "utils_log.h"
85
95
#include "utils_sanitizers.h"
86
96
@@ -163,37 +173,45 @@ static void init_cu_global_state(void) {
163
173
#else
164
174
const char * lib_name = "libcuda.so" ;
165
175
#endif
166
- // check if CUDA shared library is already loaded
167
- // we pass 0 as a handle to search the global symbol table
176
+ // The CUDA shared library should be already loaded by the user
177
+ // of the CUDA provider. UMF just want to reuse it
178
+ // and increase the reference count to the the CUDA shared library.
179
+ void * lib_handle =
180
+ utils_open_library (lib_name , UMF_UTIL_OPEN_LIBRARY_NO_LOAD );
181
+ if (!lib_handle ) {
182
+ LOG_ERR ("Failed to open CUDA shared library" );
183
+ Init_cu_global_state_failed = true;
184
+ return ;
185
+ }
168
186
169
187
// NOTE: some symbols defined in the lib have _vX postfixes - it is
170
188
// important to load the proper version of functions
171
- * (void * * )& g_cu_ops .cuMemGetAllocationGranularity =
172
- utils_get_symbol_addr ( 0 , "cuMemGetAllocationGranularity" , lib_name );
189
+ * (void * * )& g_cu_ops .cuMemGetAllocationGranularity = utils_get_symbol_addr (
190
+ lib_handle , "cuMemGetAllocationGranularity" , lib_name );
173
191
* (void * * )& g_cu_ops .cuMemAlloc =
174
- utils_get_symbol_addr (0 , "cuMemAlloc_v2" , lib_name );
192
+ utils_get_symbol_addr (lib_handle , "cuMemAlloc_v2" , lib_name );
175
193
* (void * * )& g_cu_ops .cuMemAllocHost =
176
- utils_get_symbol_addr (0 , "cuMemAllocHost_v2" , lib_name );
194
+ utils_get_symbol_addr (lib_handle , "cuMemAllocHost_v2" , lib_name );
177
195
* (void * * )& g_cu_ops .cuMemAllocManaged =
178
- utils_get_symbol_addr (0 , "cuMemAllocManaged" , lib_name );
196
+ utils_get_symbol_addr (lib_handle , "cuMemAllocManaged" , lib_name );
179
197
* (void * * )& g_cu_ops .cuMemFree =
180
- utils_get_symbol_addr (0 , "cuMemFree_v2" , lib_name );
198
+ utils_get_symbol_addr (lib_handle , "cuMemFree_v2" , lib_name );
181
199
* (void * * )& g_cu_ops .cuMemFreeHost =
182
- utils_get_symbol_addr (0 , "cuMemFreeHost" , lib_name );
200
+ utils_get_symbol_addr (lib_handle , "cuMemFreeHost" , lib_name );
183
201
* (void * * )& g_cu_ops .cuGetErrorName =
184
- utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
202
+ utils_get_symbol_addr (lib_handle , "cuGetErrorName" , lib_name );
185
203
* (void * * )& g_cu_ops .cuGetErrorString =
186
- utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
204
+ utils_get_symbol_addr (lib_handle , "cuGetErrorString" , lib_name );
187
205
* (void * * )& g_cu_ops .cuCtxGetCurrent =
188
- utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
206
+ utils_get_symbol_addr (lib_handle , "cuCtxGetCurrent" , lib_name );
189
207
* (void * * )& g_cu_ops .cuCtxSetCurrent =
190
- utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
208
+ utils_get_symbol_addr (lib_handle , "cuCtxSetCurrent" , lib_name );
191
209
* (void * * )& g_cu_ops .cuIpcGetMemHandle =
192
- utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
210
+ utils_get_symbol_addr (lib_handle , "cuIpcGetMemHandle" , lib_name );
193
211
* (void * * )& g_cu_ops .cuIpcOpenMemHandle =
194
- utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
212
+ utils_get_symbol_addr (lib_handle , "cuIpcOpenMemHandle_v2" , lib_name );
195
213
* (void * * )& g_cu_ops .cuIpcCloseMemHandle =
196
- utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
214
+ utils_get_symbol_addr (lib_handle , "cuIpcCloseMemHandle" , lib_name );
197
215
198
216
if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
199
217
!g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
@@ -204,7 +222,10 @@ static void init_cu_global_state(void) {
204
222
!g_cu_ops .cuIpcCloseMemHandle ) {
205
223
LOG_ERR ("Required CUDA symbols not found." );
206
224
Init_cu_global_state_failed = true;
225
+ utils_close_library (lib_handle );
226
+ return ;
207
227
}
228
+ cu_lib_handle = lib_handle ;
208
229
}
209
230
210
231
umf_result_t umfCUDAMemoryProviderParamsCreate (
@@ -297,7 +318,7 @@ static umf_result_t cu_memory_provider_initialize(void *params,
297
318
utils_init_once (& cu_is_initialized , init_cu_global_state );
298
319
if (Init_cu_global_state_failed ) {
299
320
LOG_ERR ("Loading CUDA symbols failed" );
300
- return UMF_RESULT_ERROR_UNKNOWN ;
321
+ return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE ;
301
322
}
302
323
303
324
cu_memory_provider_t * cu_provider =
0 commit comments