Skip to content

Commit 1bb7151

Browse files
committed
Rewrite load_library() in nvjitlink_windows.pyx to use path_finder.find_nvidia_dynamic_library()
1 parent c2136ea commit 1bb7151

File tree

1 file changed

+26
-51
lines changed

1 file changed

+26
-51
lines changed

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 26 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@
66

77
from libc.stdint cimport intptr_t
88

9-
from .utils cimport get_nvjitlink_dso_version_suffix
10-
119
from .utils import FunctionNotFoundError, NotSupportedError
1210

13-
import os
14-
import site
11+
from cuda.bindings import path_finder
1512

1613
import win32api
1714

@@ -42,54 +39,32 @@ cdef void* __nvJitLinkGetInfoLog = NULL
4239
cdef void* __nvJitLinkVersion = NULL
4340

4441

45-
cdef inline list get_site_packages():
46-
return [site.getusersitepackages()] + site.getsitepackages()
47-
48-
4942
cdef load_library(const int driver_ver):
50-
handle = 0
51-
52-
for suffix in get_nvjitlink_dso_version_suffix(driver_ver):
53-
if len(suffix) == 0:
54-
continue
55-
dll_name = f"nvJitLink_{suffix}0_0.dll"
56-
57-
# First check if the DLL has been loaded by 3rd parties
58-
try:
59-
handle = win32api.GetModuleHandle(dll_name)
60-
except:
61-
pass
62-
else:
63-
break
64-
65-
# Next, check if DLLs are installed via pip
66-
for sp in get_site_packages():
67-
mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin")
68-
if not os.path.isdir(mod_path):
69-
continue
70-
os.add_dll_directory(mod_path)
71-
try:
72-
handle = win32api.LoadLibraryEx(
73-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
74-
os.path.join(mod_path, dll_name),
75-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
76-
except:
77-
pass
78-
else:
79-
break
80-
81-
# Finally, try default search
82-
try:
83-
handle = win32api.LoadLibrary(dll_name)
84-
except:
85-
pass
86-
else:
87-
break
88-
else:
89-
raise RuntimeError('Failed to load nvJitLink')
90-
91-
assert handle != 0
92-
return handle
43+
dll_name = path_finder.find_nvidia_dynamic_library("nvJitLink")
44+
45+
errors = [f"Failed to load {dll_name}", "Exceptions encountered:"]
46+
47+
# First check if the DLL has been loaded by 3rd parties
48+
try:
49+
return win32api.GetModuleHandle(dll_name)
50+
except BaseException as e:
51+
errors.append(f"{type(e)}: {str(e)}")
52+
53+
try:
54+
return win32api.LoadLibraryEx(
55+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
56+
dll_name,
57+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
58+
except BaseException as e:
59+
errors.append(f"{type(e)}: {str(e)}")
60+
61+
# Finally, try default search
62+
try:
63+
return win32api.LoadLibrary(dll_name)
64+
except BaseException as e:
65+
errors.append(f"{type(e)}: {str(e)}")
66+
67+
raise RuntimeError("\n".join(errors))
9368

9469

9570
cdef int _check_or_init_nvjitlink() except -1 nogil:

0 commit comments

Comments
 (0)