|
6 | 6 |
|
7 | 7 | from libc.stdint cimport intptr_t
|
8 | 8 |
|
9 |
| -from .utils cimport get_nvjitlink_dso_version_suffix |
10 |
| - |
11 | 9 | from .utils import FunctionNotFoundError, NotSupportedError
|
12 | 10 |
|
13 |
| -import os |
14 |
| -import site |
| 11 | +from cuda.bindings import path_finder |
15 | 12 |
|
16 | 13 | import win32api
|
17 | 14 |
|
@@ -42,54 +39,32 @@ cdef void* __nvJitLinkGetInfoLog = NULL
|
42 | 39 | cdef void* __nvJitLinkVersion = NULL
|
43 | 40 |
|
44 | 41 |
|
45 |
| -cdef inline list get_site_packages(): |
46 |
| - return [site.getusersitepackages()] + site.getsitepackages() |
47 |
| - |
48 |
| - |
49 | 42 | cdef load_library(const int driver_ver):
|
50 |
| - handle = 0 |
51 |
| - |
52 |
| - for suffix in get_nvjitlink_dso_version_suffix(driver_ver): |
53 |
| - if len(suffix) == 0: |
54 |
| - continue |
55 |
| - dll_name = f"nvJitLink_{suffix}0_0.dll" |
56 |
| - |
57 |
| - # First check if the DLL has been loaded by 3rd parties |
58 |
| - try: |
59 |
| - handle = win32api.GetModuleHandle(dll_name) |
60 |
| - except: |
61 |
| - pass |
62 |
| - else: |
63 |
| - break |
64 |
| - |
65 |
| - # Next, check if DLLs are installed via pip |
66 |
| - for sp in get_site_packages(): |
67 |
| - mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin") |
68 |
| - if not os.path.isdir(mod_path): |
69 |
| - continue |
70 |
| - os.add_dll_directory(mod_path) |
71 |
| - try: |
72 |
| - handle = win32api.LoadLibraryEx( |
73 |
| - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... |
74 |
| - os.path.join(mod_path, dll_name), |
75 |
| - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) |
76 |
| - except: |
77 |
| - pass |
78 |
| - else: |
79 |
| - break |
80 |
| - |
81 |
| - # Finally, try default search |
82 |
| - try: |
83 |
| - handle = win32api.LoadLibrary(dll_name) |
84 |
| - except: |
85 |
| - pass |
86 |
| - else: |
87 |
| - break |
88 |
| - else: |
89 |
| - raise RuntimeError('Failed to load nvJitLink') |
90 |
| - |
91 |
| - assert handle != 0 |
92 |
| - return handle |
| 43 | + dll_name = path_finder.find_nvidia_dynamic_library("nvJitLink") |
| 44 | + |
| 45 | + errors = [f"Failed to load {dll_name}", "Exceptions encountered:"] |
| 46 | + |
| 47 | + # First check if the DLL has been loaded by 3rd parties |
| 48 | + try: |
| 49 | + return win32api.GetModuleHandle(dll_name) |
| 50 | + except BaseException as e: |
| 51 | + errors.append(f"{type(e)}: {str(e)}") |
| 52 | + |
| 53 | + try: |
| 54 | + return win32api.LoadLibraryEx( |
| 55 | + # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... |
| 56 | + dll_name, |
| 57 | + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) |
| 58 | + except BaseException as e: |
| 59 | + errors.append(f"{type(e)}: {str(e)}") |
| 60 | + |
| 61 | + # Finally, try default search |
| 62 | + try: |
| 63 | + return win32api.LoadLibrary(dll_name) |
| 64 | + except BaseException as e: |
| 65 | + errors.append(f"{type(e)}: {str(e)}") |
| 66 | + |
| 67 | + raise RuntimeError("\n".join(errors)) |
93 | 68 |
|
94 | 69 |
|
95 | 70 | cdef int _check_or_init_nvjitlink() except -1 nogil:
|
|
0 commit comments