Skip to content

Commit 7a0c068

Browse files
authored
Make path_finder work for "nvrtc" (#553)
* Revert "Restore cuda/bindings/_bindings/cynvrtc.pyx.in as-is on main" This reverts commit ba093f5. * Revert "Reapply "Revert debug changes under .github/workflows"" This reverts commit 8f69f83. * Also load nvrtc from cuda_bindings/tests/path_finder.py * Add heuristics for nvidia_cuda_nvrtc Windows wheels. Also fix a couple bugs discovered by ChatGPT: * `glob.glob()` in this code return absolute paths. * stray `error_messages = []` * Add debug prints, mostly for `os.add_dll_directory(bin_dir)` * Fix unfortunate silly oversight (import os missing under Windows) * Use `win32api.LoadLibraryEx()` with suitable `flags`; also update `os.environ["PATH"]` * Hard-wire WinBase.h constants (they are not exposed by win32con) * Remove debug prints * Reapply "Reapply "Revert debug changes under .github/workflows"" This reverts commit b002ff6.
1 parent 147b242 commit 7a0c068

File tree

4 files changed

+58
-69
lines changed

4 files changed

+58
-69
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 9 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
# This code was automatically generated with version 12.8.0. Do not modify it directly.
1010
{{if 'Windows' == platform.system()}}
1111
import os
12-
import site
13-
import struct
1412
import win32api
15-
from pywintypes import error
1613
{{else}}
1714
cimport cuda.bindings._lib.dlfcn as dlfcn
15+
from libc.stdint cimport uintptr_t
1816
{{endif}}
17+
from cuda.bindings import path_finder
1918

2019
cdef bint __cuPythonInit = False
2120
{{if 'nvrtcGetErrorString' in found_functions}}cdef void *__nvrtcGetErrorString = NULL{{endif}}
@@ -46,64 +45,18 @@ cdef bint __cuPythonInit = False
4645
{{if 'nvrtcSetFlowCallback' in found_functions}}cdef void *__nvrtcSetFlowCallback = NULL{{endif}}
4746

4847
cdef int cuPythonInit() except -1 nogil:
48+
{{if 'Windows' != platform.system()}}
49+
cdef void* handle = NULL
50+
{{endif}}
51+
4952
global __cuPythonInit
5053
if __cuPythonInit:
5154
return 0
5255
__cuPythonInit = True
5356

54-
# Load library
55-
{{if 'Windows' == platform.system()}}
56-
with gil:
57-
# First check if the DLL has been loaded by 3rd parties
58-
try:
59-
handle = win32api.GetModuleHandle("nvrtc64_120_0.dll")
60-
except:
61-
handle = None
62-
63-
# Else try default search
64-
if not handle:
65-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
66-
try:
67-
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
68-
except:
69-
pass
70-
71-
# Final check if DLLs can be found within pip installations
72-
if not handle:
73-
site_packages = [site.getusersitepackages()] + site.getsitepackages()
74-
for sp in site_packages:
75-
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
76-
if not os.path.isdir(mod_path):
77-
continue
78-
os.add_dll_directory(mod_path)
79-
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
80-
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
81-
try:
82-
handle = win32api.LoadLibraryEx(
83-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
84-
os.path.join(mod_path, "nvrtc64_120_0.dll"),
85-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
86-
87-
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
88-
# located in the same mod_path.
89-
# Update PATH environ so that the two dlls can find each other
90-
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
91-
except:
92-
pass
93-
94-
if not handle:
95-
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
96-
{{else}}
97-
handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)
98-
if handle == NULL:
99-
with gil:
100-
raise RuntimeError('Failed to dlopen libnvrtc.so.12')
101-
{{endif}}
102-
103-
104-
# Load function
10557
{{if 'Windows' == platform.system()}}
10658
with gil:
59+
handle = path_finder.load_nvidia_dynamic_library("nvrtc")
10760
{{if 'nvrtcGetErrorString' in found_functions}}
10861
try:
10962
global __nvrtcGetErrorString
@@ -288,6 +241,8 @@ cdef int cuPythonInit() except -1 nogil:
288241
{{endif}}
289242

290243
{{else}}
244+
with gil:
245+
handle = <void*><uintptr_t>path_finder.load_nvidia_dynamic_library("nvrtc")
291246
{{if 'nvrtcGetErrorString' in found_functions}}
292247
global __nvrtcGetErrorString
293248
__nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')

cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,50 @@ def _find_so_using_nvidia_lib_dirs(libname, so_basename, error_messages, attachm
3131
return so_name
3232
# Look for a versioned library
3333
# Using sort here mainly to make the result deterministic.
34-
for node in sorted(glob.glob(os.path.join(lib_dir, file_wild))):
35-
so_name = os.path.join(lib_dir, node)
34+
for so_name in sorted(glob.glob(os.path.join(lib_dir, file_wild))):
3635
if os.path.isfile(so_name):
3736
return so_name
3837
_no_such_file_in_sub_dirs(nvidia_sub_dirs, file_wild, error_messages, attachments)
3938
return None
4039

4140

41+
def _append_to_os_environ_path(dirpath):
42+
curr_path = os.environ.get("PATH")
43+
os.environ["PATH"] = dirpath if curr_path is None else os.pathsep.join((curr_path, dirpath))
44+
45+
4246
def _find_dll_using_nvidia_bin_dirs(libname, error_messages, attachments):
4347
if libname == "nvvm": # noqa: SIM108
4448
nvidia_sub_dirs = ("nvidia", "*", "nvvm", "bin")
4549
else:
4650
nvidia_sub_dirs = ("nvidia", "*", "bin")
4751
file_wild = libname + "*.dll"
4852
for bin_dir in sys_path_find_sub_dirs(nvidia_sub_dirs):
49-
for node in sorted(glob.glob(os.path.join(bin_dir, file_wild))):
50-
dll_name = os.path.join(bin_dir, node)
51-
if os.path.isfile(dll_name):
52-
return dll_name
53+
dll_name = None
54+
have_builtins = False
55+
for path in sorted(glob.glob(os.path.join(bin_dir, file_wild))):
56+
# nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl:
57+
# nvidia\cuda_nvrtc\bin\
58+
# nvrtc-builtins64_128.dll
59+
# nvrtc64_120_0.alt.dll
60+
# nvrtc64_120_0.dll
61+
node = os.path.basename(path)
62+
if node.endswith(".alt.dll"):
63+
continue
64+
if "-builtins" in node:
65+
have_builtins = True
66+
continue
67+
if dll_name is not None:
68+
continue
69+
if os.path.isfile(path):
70+
dll_name = path
71+
if dll_name is not None:
72+
if have_builtins:
73+
# Add the DLL directory to the search path
74+
os.add_dll_directory(bin_dir)
75+
# Update PATH as a fallback for dependent DLL resolution
76+
_append_to_os_environ_path(bin_dir)
77+
return dll_name
5378
_no_such_file_in_sub_dirs(nvidia_sub_dirs, file_wild, error_messages, attachments)
5479
return None
5580

@@ -78,7 +103,6 @@ def _find_so_using_cudalib_dir(so_basename, error_messages, attachments):
78103
candidate_so_dirs.append(alt_dir)
79104
libs.reverse()
80105
candidate_so_names = [so_dirname + so_basename for so_dirname in candidate_so_dirs]
81-
error_messages = []
82106
for so_name in candidate_so_names:
83107
if os.path.isfile(so_name):
84108
return so_name
@@ -98,8 +122,7 @@ def _find_dll_using_cudalib_dir(libname, error_messages, attachments):
98122
if cudalib_dir is None:
99123
return None
100124
file_wild = libname + "*.dll"
101-
for node in sorted(glob.glob(os.path.join(cudalib_dir, file_wild))):
102-
dll_name = os.path.join(cudalib_dir, node)
125+
for dll_name in sorted(glob.glob(os.path.join(cudalib_dir, file_wild))):
103126
if os.path.isfile(dll_name):
104127
return dll_name
105128
error_messages.append(f"No such file: {file_wild}")
@@ -123,7 +146,7 @@ def find_nvidia_dynamic_library(name: str) -> str:
123146
dll_name = _find_dll_using_cudalib_dir(name, error_messages, attachments)
124147
if dll_name is None:
125148
attachments = "\n".join(attachments)
126-
raise RuntimeError(f"Failure finding {name}*.dll: {', '.join(error_messages)}\n{attachments}")
149+
raise RuntimeError(f'Failure finding "{name}*.dll": {", ".join(error_messages)}\n{attachments}')
127150
return dll_name
128151

129152
so_basename = f"lib{name}.so"
@@ -135,5 +158,5 @@ def find_nvidia_dynamic_library(name: str) -> str:
135158
so_name = _find_so_using_cudalib_dir(so_basename, error_messages, attachments)
136159
if so_name is None:
137160
attachments = "\n".join(attachments)
138-
raise RuntimeError(f"Failure finding {so_basename}: {', '.join(error_messages)}\n{attachments}")
161+
raise RuntimeError(f'Failure finding "{so_basename}": {", ".join(error_messages)}\n{attachments}')
139162
return so_name

cuda_bindings/cuda/bindings/_path_finder/load_nvidia_dynamic_library.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import win32api
99

1010
# Mirrors WinBase.h (unfortunately not defined already elsewhere)
11-
_WINBASE_LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
11+
_WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
12+
_WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
1213

1314
else:
1415
import ctypes
@@ -77,8 +78,9 @@ def load_nvidia_dynamic_library(name: str) -> int:
7778

7879
dl_path = find_nvidia_dynamic_library(name)
7980
if sys.platform == "win32":
81+
flags = _WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | _WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
8082
try:
81-
handle = win32api.LoadLibrary(dl_path)
83+
handle = win32api.LoadLibraryEx(dl_path, 0, flags)
8284
except pywintypes.error as e:
8385
raise RuntimeError(f"Failed to load DLL at {dl_path}: {e}") from e
8486
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*

cuda_bindings/tests/path_finder.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44

55
for k, v in paths.items():
66
print(f"{k}: {v}", flush=True)
7+
print()
78

8-
print(path_finder.find_nvidia_dynamic_library("nvvm"))
9-
print(path_finder.find_nvidia_dynamic_library("nvJitLink"))
9+
libnames = ("nvJitLink", "nvrtc", "nvvm")
10+
11+
for libname in libnames:
12+
print(path_finder.find_nvidia_dynamic_library(libname))
13+
print()
14+
15+
for libname in libnames:
16+
print(libname)
17+
print(path_finder.load_nvidia_dynamic_library(libname))
18+
print()

0 commit comments

Comments
 (0)