Skip to content

Commit dd138dc

Browse files
authored
Use path_finder in bindings (nvJitLink, nvrtc, nvvm) (#614)
* Change nvJitLink, nvrtc, nvvm bindings to use path_finder * Restore nvvm-related LD_LIBRARY_PATH, PATH manipulations from main branch. * Update README.md to reflect new search priority
1 parent 938c9e9 commit dd138dc

File tree

8 files changed

+21
-214
lines changed

8 files changed

+21
-214
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 5 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
# This code was automatically generated with version 12.9.0. Do not modify it directly.
1010
{{if 'Windows' == platform.system()}}
1111
import os
12-
import site
13-
import struct
1412
import win32api
15-
from pywintypes import error
1613
{{else}}
1714
cimport cuda.bindings._lib.dlfcn as dlfcn
15+
from libc.stdint cimport uintptr_t
1816
{{endif}}
17+
from cuda.bindings import path_finder
1918

2019
from libc.stdint cimport intptr_t
2120

@@ -56,51 +55,10 @@ cdef int cuPythonInit() except -1 nogil:
5655
# Load library
5756
{{if 'Windows' == platform.system()}}
5857
with gil:
59-
# First check if the DLL has been loaded by 3rd parties
60-
try:
61-
handle = win32api.GetModuleHandle("nvrtc64_120_0.dll")
62-
except:
63-
handle = None
64-
65-
# Check if DLLs can be found within pip installations
66-
if not handle:
67-
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
68-
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
69-
site_packages = [site.getusersitepackages()] + site.getsitepackages()
70-
for sp in site_packages:
71-
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
72-
if os.path.isdir(mod_path):
73-
os.add_dll_directory(mod_path)
74-
try:
75-
handle = win32api.LoadLibraryEx(
76-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
77-
os.path.join(mod_path, "nvrtc64_120_0.dll"),
78-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
79-
80-
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
81-
# located in the same mod_path.
82-
# Update PATH environ so that the two dlls can find each other
83-
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
84-
except:
85-
pass
86-
else:
87-
break
88-
else:
89-
# Else try default search
90-
# Only reached if DLL wasn't found in any site-package path
91-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
92-
try:
93-
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
94-
except:
95-
pass
96-
97-
if not handle:
98-
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
58+
handle = path_finder._load_nvidia_dynamic_library("nvrtc").handle
9959
{{else}}
100-
handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)
101-
if handle == NULL:
102-
with gil:
103-
raise RuntimeError('Failed to dlopen libnvrtc.so.12')
60+
with gil:
61+
handle = <void*><uintptr_t>path_finder._load_nvidia_dynamic_library("nvrtc").handle
10462
{{endif}}
10563

10664

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
#
55
# This code was automatically generated across versions from 12.0.1 to 12.9.0. Do not modify it directly.
66

7-
from libc.stdint cimport intptr_t
8-
9-
from .utils cimport get_nvjitlink_dso_version_suffix
7+
from libc.stdint cimport intptr_t, uintptr_t
108

119
from .utils import FunctionNotFoundError, NotSupportedError
1210

11+
from cuda.bindings import path_finder
12+
1313
###############################################################################
1414
# Extern
1515
###############################################################################
@@ -52,17 +52,9 @@ cdef void* __nvJitLinkGetInfoLog = NULL
5252
cdef void* __nvJitLinkVersion = NULL
5353

5454

55-
cdef void* load_library(const int driver_ver) except* with gil:
56-
cdef void* handle
57-
for suffix in get_nvjitlink_dso_version_suffix(driver_ver):
58-
so_name = "libnvJitLink.so" + (f".{suffix}" if suffix else suffix)
59-
handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
60-
if handle != NULL:
61-
break
62-
else:
63-
err_msg = dlerror()
64-
raise RuntimeError(f'Failed to dlopen libnvJitLink ({err_msg.decode()})')
65-
return handle
55+
cdef void* load_library(int driver_ver) except* with gil:
56+
cdef uintptr_t handle = path_finder._load_nvidia_dynamic_library("nvJitLink").handle
57+
return <void*>handle
6658

6759

6860
cdef int _check_or_init_nvjitlink() except -1 nogil:

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@
66

77
from libc.stdint cimport intptr_t
88

9-
from .utils cimport get_nvjitlink_dso_version_suffix
10-
119
from .utils import FunctionNotFoundError, NotSupportedError
1210

13-
import os
14-
import site
11+
from cuda.bindings import path_finder
1512

1613
import win32api
1714

@@ -42,46 +39,6 @@ cdef void* __nvJitLinkGetInfoLog = NULL
4239
cdef void* __nvJitLinkVersion = NULL
4340

4441

45-
cdef inline list get_site_packages():
46-
return [site.getusersitepackages()] + site.getsitepackages()
47-
48-
49-
cdef load_library(const int driver_ver):
50-
handle = 0
51-
52-
for suffix in get_nvjitlink_dso_version_suffix(driver_ver):
53-
if len(suffix) == 0:
54-
continue
55-
dll_name = f"nvJitLink_{suffix}0_0.dll"
56-
57-
# First check if the DLL has been loaded by 3rd parties
58-
try:
59-
return win32api.GetModuleHandle(dll_name)
60-
except:
61-
pass
62-
63-
# Next, check if DLLs are installed via pip
64-
for sp in get_site_packages():
65-
mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin")
66-
if os.path.isdir(mod_path):
67-
os.add_dll_directory(mod_path)
68-
try:
69-
return win32api.LoadLibraryEx(
70-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
71-
os.path.join(mod_path, dll_name),
72-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
73-
except:
74-
pass
75-
# Finally, try default search
76-
# Only reached if DLL wasn't found in any site-package path
77-
try:
78-
return win32api.LoadLibrary(dll_name)
79-
except:
80-
pass
81-
82-
raise RuntimeError('Failed to load nvJitLink')
83-
84-
8542
cdef int _check_or_init_nvjitlink() except -1 nogil:
8643
global __py_nvjitlink_init
8744
if __py_nvjitlink_init:
@@ -104,7 +61,7 @@ cdef int _check_or_init_nvjitlink() except -1 nogil:
10461
raise RuntimeError('something went wrong')
10562

10663
# Load library
107-
handle = load_library(driver_ver)
64+
handle = path_finder._load_nvidia_dynamic_library("nvJitLink").handle
10865

10966
# Load function
11067
global __nvJitLinkCreate

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
#
55
# This code was automatically generated across versions from 11.0.3 to 12.9.0. Do not modify it directly.
66

7-
from libc.stdint cimport intptr_t
8-
9-
from .utils cimport get_nvvm_dso_version_suffix
7+
from libc.stdint cimport intptr_t, uintptr_t
108

119
from .utils import FunctionNotFoundError, NotSupportedError
1210

11+
from cuda.bindings import path_finder
12+
1313
###############################################################################
1414
# Extern
1515
###############################################################################
@@ -51,16 +51,8 @@ cdef void* __nvvmGetProgramLog = NULL
5151

5252

5353
cdef void* load_library(const int driver_ver) except* with gil:
54-
cdef void* handle
55-
for suffix in get_nvvm_dso_version_suffix(driver_ver):
56-
so_name = "libnvvm.so" + (f".{suffix}" if suffix else suffix)
57-
handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
58-
if handle != NULL:
59-
break
60-
else:
61-
err_msg = dlerror()
62-
raise RuntimeError(f'Failed to dlopen libnvvm ({err_msg.decode()})')
63-
return handle
54+
cdef uintptr_t handle = path_finder._load_nvidia_dynamic_library("nvvm").handle
55+
return <void*>handle
6456

6557

6658
cdef int _check_or_init_nvvm() except -1 nogil:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 2 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@
66

77
from libc.stdint cimport intptr_t
88

9-
from .utils cimport get_nvvm_dso_version_suffix
10-
119
from .utils import FunctionNotFoundError, NotSupportedError
1210

13-
import os
14-
import site
11+
from cuda.bindings import path_finder
1512

1613
import win32api
1714

@@ -40,54 +37,6 @@ cdef void* __nvvmGetProgramLogSize = NULL
4037
cdef void* __nvvmGetProgramLog = NULL
4138

4239

43-
cdef inline list get_site_packages():
44-
return [site.getusersitepackages()] + site.getsitepackages() + ["conda"]
45-
46-
47-
cdef load_library(const int driver_ver):
48-
handle = 0
49-
50-
for suffix in get_nvvm_dso_version_suffix(driver_ver):
51-
if len(suffix) == 0:
52-
continue
53-
dll_name = "nvvm64_40_0.dll"
54-
55-
# First check if the DLL has been loaded by 3rd parties
56-
try:
57-
return win32api.GetModuleHandle(dll_name)
58-
except:
59-
pass
60-
61-
# Next, check if DLLs are installed via pip or conda
62-
for sp in get_site_packages():
63-
if sp == "conda":
64-
# nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
65-
conda_prefix = os.environ.get("CONDA_PREFIX")
66-
if conda_prefix is None:
67-
continue
68-
mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin")
69-
else:
70-
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
71-
if os.path.isdir(mod_path):
72-
os.add_dll_directory(mod_path)
73-
try:
74-
return win32api.LoadLibraryEx(
75-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76-
os.path.join(mod_path, dll_name),
77-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78-
except:
79-
pass
80-
81-
# Finally, try default search
82-
# Only reached if DLL wasn't found in any site-package path
83-
try:
84-
return win32api.LoadLibrary(dll_name)
85-
except:
86-
pass
87-
88-
raise RuntimeError('Failed to load nvvm')
89-
90-
9140
cdef int _check_or_init_nvvm() except -1 nogil:
9241
global __py_nvvm_init
9342
if __py_nvvm_init:
@@ -110,7 +59,7 @@ cdef int _check_or_init_nvvm() except -1 nogil:
11059
raise RuntimeError('something went wrong')
11160

11261
# Load library
113-
handle = load_library(driver_ver)
62+
handle = path_finder._load_nvidia_dynamic_library("nvvm").handle
11463

11564
# Load function
11665
global __nvvmVersion

cuda_bindings/cuda/bindings/_internal/utils.pxd

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,3 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj,
165165

166166
cdef bint is_nested_sequence(data)
167167
cdef void* get_buffer_pointer(buf, Py_ssize_t size, readonly=*) except*
168-
169-
cdef tuple get_nvjitlink_dso_version_suffix(int driver_ver)
170-
cdef tuple get_nvvm_dso_version_suffix(int driver_ver)

cuda_bindings/cuda/bindings/_internal/utils.pyx

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,3 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj,
127127
class FunctionNotFoundError(RuntimeError): pass
128128

129129
class NotSupportedError(RuntimeError): pass
130-
131-
132-
cdef tuple get_nvjitlink_dso_version_suffix(int driver_ver):
133-
if 12000 <= driver_ver < 13000:
134-
return ('12', '')
135-
raise NotSupportedError(f'CUDA driver version {driver_ver} is not supported')
136-
137-
138-
cdef tuple get_nvvm_dso_version_suffix(int driver_ver):
139-
if 11000 <= driver_ver < 11020:
140-
return ('3', '')
141-
if 11020 <= driver_ver < 13000:
142-
return ('4', '')
143-
raise NotSupportedError(f'CUDA driver version {driver_ver} is not supported')

cuda_bindings/setup.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -379,31 +379,7 @@ def initialize_options(self):
379379
def build_extension(self, ext):
380380
if building_wheel and sys.platform == "linux":
381381
# Strip binaries to remove debug symbols
382-
extra_linker_flags = ["-Wl,--strip-all"]
383-
384-
# Allow extensions to discover libraries at runtime
385-
# relative their wheels installation.
386-
if ext.name == "cuda.bindings._bindings.cynvrtc":
387-
ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
388-
elif ext.name == "cuda.bindings._internal.nvjitlink":
389-
ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
390-
elif ext.name == "cuda.bindings._internal.nvvm":
391-
# from <loc>/site-packages/cuda/bindings/_internal/
392-
# to <loc>/site-packages/nvidia/cuda_nvcc/nvvm/lib64/
393-
rel1 = "$ORIGIN/../../../nvidia/cuda_nvcc/nvvm/lib64"
394-
# from <loc>/lib/python3.*/site-packages/cuda/bindings/_internal/
395-
# to <loc>/nvvm/lib64/
396-
rel2 = "$ORIGIN/../../../../../../nvvm/lib64"
397-
ldflag = f"-Wl,--disable-new-dtags,-rpath,{rel1},-rpath,{rel2}"
398-
else:
399-
ldflag = None
400-
401-
if ldflag:
402-
extra_linker_flags.append(ldflag)
403-
else:
404-
extra_linker_flags = []
405-
406-
ext.extra_link_args += extra_linker_flags
382+
ext.extra_link_args.append("-Wl,--strip-all")
407383
super().build_extension(ext)
408384

409385

0 commit comments

Comments
 (0)