Skip to content

Commit 7d8ab70

Browse files
committed
Change nvJitLink, nvrtc, nvvm bindings to use path_finder
1 parent 27db0a7 commit 7d8ab70

File tree

8 files changed

+37
-199
lines changed

8 files changed

+37
-199
lines changed

.github/actions/fetch_ctk/action.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,13 @@ runs:
151151
# mimics actual CTK installation
152152
if [[ "${{ inputs.host-platform }}" == linux* ]]; then
153153
CUDA_PATH=$(realpath "./cuda_toolkit")
154-
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${CUDA_PATH}/lib:${CUDA_PATH}/nvvm/lib64" >> $GITHUB_ENV
154+
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${CUDA_PATH}/lib" >> $GITHUB_ENV
155155
elif [[ "${{ inputs.host-platform }}" == win* ]]; then
156156
function normpath() {
157157
echo "$(echo $(cygpath -w $1) | sed 's/\\/\\\\/g')"
158158
}
159159
CUDA_PATH=$(normpath $(realpath "./cuda_toolkit"))
160160
echo "$(normpath ${CUDA_PATH}/bin)" >> $GITHUB_PATH
161-
echo "$(normpath $CUDA_PATH/nvvm/bin)" >> $GITHUB_PATH
162161
fi
163162
echo "CUDA_PATH=${CUDA_PATH}" >> $GITHUB_ENV
164163
echo "CUDA_HOME=${CUDA_PATH}" >> $GITHUB_ENV

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 9 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
# This code was automatically generated with version 12.9.0. Do not modify it directly.
1010
{{if 'Windows' == platform.system()}}
1111
import os
12-
import site
13-
import struct
1412
import win32api
15-
from pywintypes import error
1613
{{else}}
1714
cimport cuda.bindings._lib.dlfcn as dlfcn
15+
from libc.stdint cimport uintptr_t
1816
{{endif}}
17+
from cuda.bindings import path_finder
1918

2019
from libc.stdint cimport intptr_t
2120

@@ -48,65 +47,18 @@ cdef bint __cuPythonInit = False
4847
{{if 'nvrtcSetFlowCallback' in found_functions}}cdef void *__nvrtcSetFlowCallback = NULL{{endif}}
4948

5049
cdef int cuPythonInit() except -1 nogil:
50+
{{if 'Windows' != platform.system()}}
51+
cdef void* handle = NULL
52+
{{endif}}
53+
5154
global __cuPythonInit
5255
if __cuPythonInit:
5356
return 0
5457
__cuPythonInit = True
5558

56-
# Load library
57-
{{if 'Windows' == platform.system()}}
58-
with gil:
59-
# First check if the DLL has been loaded by 3rd parties
60-
try:
61-
handle = win32api.GetModuleHandle("nvrtc64_120_0.dll")
62-
except:
63-
handle = None
64-
65-
# Check if DLLs can be found within pip installations
66-
if not handle:
67-
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
68-
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
69-
site_packages = [site.getusersitepackages()] + site.getsitepackages()
70-
for sp in site_packages:
71-
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
72-
if os.path.isdir(mod_path):
73-
os.add_dll_directory(mod_path)
74-
try:
75-
handle = win32api.LoadLibraryEx(
76-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
77-
os.path.join(mod_path, "nvrtc64_120_0.dll"),
78-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
79-
80-
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
81-
# located in the same mod_path.
82-
# Update PATH environ so that the two dlls can find each other
83-
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
84-
except:
85-
pass
86-
else:
87-
break
88-
else:
89-
# Else try default search
90-
# Only reached if DLL wasn't found in any site-package path
91-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
92-
try:
93-
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
94-
except:
95-
pass
96-
97-
if not handle:
98-
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
99-
{{else}}
100-
handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)
101-
if handle == NULL:
102-
with gil:
103-
raise RuntimeError('Failed to dlopen libnvrtc.so.12')
104-
{{endif}}
105-
106-
107-
# Load function
10859
{{if 'Windows' == platform.system()}}
10960
with gil:
61+
handle = path_finder._load_nvidia_dynamic_library("nvrtc").handle
11062
{{if 'nvrtcGetErrorString' in found_functions}}
11163
try:
11264
global __nvrtcGetErrorString
@@ -291,6 +243,8 @@ cdef int cuPythonInit() except -1 nogil:
291243
{{endif}}
292244

293245
{{else}}
246+
with gil:
247+
handle = <void*><uintptr_t>path_finder._load_nvidia_dynamic_library("nvrtc").handle
294248
{{if 'nvrtcGetErrorString' in found_functions}}
295249
global __nvrtcGetErrorString
296250
__nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
#
55
# This code was automatically generated across versions from 12.0.1 to 12.9.0. Do not modify it directly.
66

7-
from libc.stdint cimport intptr_t
8-
9-
from .utils cimport get_nvjitlink_dso_version_suffix
7+
from libc.stdint cimport intptr_t, uintptr_t
108

119
from .utils import FunctionNotFoundError, NotSupportedError
1210

11+
from cuda.bindings import path_finder
12+
1313
###############################################################################
1414
# Extern
1515
###############################################################################
@@ -52,17 +52,9 @@ cdef void* __nvJitLinkGetInfoLog = NULL
5252
cdef void* __nvJitLinkVersion = NULL
5353

5454

55-
cdef void* load_library(const int driver_ver) except* with gil:
56-
cdef void* handle
57-
for suffix in get_nvjitlink_dso_version_suffix(driver_ver):
58-
so_name = "libnvJitLink.so" + (f".{suffix}" if suffix else suffix)
59-
handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
60-
if handle != NULL:
61-
break
62-
else:
63-
err_msg = dlerror()
64-
raise RuntimeError(f'Failed to dlopen libnvJitLink ({err_msg.decode()})')
65-
return handle
55+
cdef void* load_library(int driver_ver) except* with gil:
56+
cdef uintptr_t handle = path_finder._load_nvidia_dynamic_library("nvJitLink").handle
57+
return <void*>handle
6658

6759

6860
cdef int _check_or_init_nvjitlink() except -1 nogil:

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@
66

77
from libc.stdint cimport intptr_t
88

9-
from .utils cimport get_nvjitlink_dso_version_suffix
10-
119
from .utils import FunctionNotFoundError, NotSupportedError
1210

13-
import os
14-
import site
11+
from cuda.bindings import path_finder
1512

1613
import win32api
1714

@@ -42,44 +39,9 @@ cdef void* __nvJitLinkGetInfoLog = NULL
4239
cdef void* __nvJitLinkVersion = NULL
4340

4441

45-
cdef inline list get_site_packages():
46-
return [site.getusersitepackages()] + site.getsitepackages()
47-
48-
49-
cdef load_library(const int driver_ver):
50-
handle = 0
51-
52-
for suffix in get_nvjitlink_dso_version_suffix(driver_ver):
53-
if len(suffix) == 0:
54-
continue
55-
dll_name = f"nvJitLink_{suffix}0_0.dll"
56-
57-
# First check if the DLL has been loaded by 3rd parties
58-
try:
59-
return win32api.GetModuleHandle(dll_name)
60-
except:
61-
pass
62-
63-
# Next, check if DLLs are installed via pip
64-
for sp in get_site_packages():
65-
mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin")
66-
if os.path.isdir(mod_path):
67-
os.add_dll_directory(mod_path)
68-
try:
69-
return win32api.LoadLibraryEx(
70-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
71-
os.path.join(mod_path, dll_name),
72-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
73-
except:
74-
pass
75-
# Finally, try default search
76-
# Only reached if DLL wasn't found in any site-package path
77-
try:
78-
return win32api.LoadLibrary(dll_name)
79-
except:
80-
pass
81-
82-
raise RuntimeError('Failed to load nvJitLink')
42+
cdef void* load_library(int driver_ver) except* with gil:
43+
cdef intptr_t handle = path_finder._load_nvidia_dynamic_library("nvJitLink").handle
44+
return <void*>handle
8345

8446

8547
cdef int _check_or_init_nvjitlink() except -1 nogil:
@@ -88,23 +50,24 @@ cdef int _check_or_init_nvjitlink() except -1 nogil:
8850
return 0
8951

9052
cdef int err, driver_ver
53+
cdef intptr_t handle
9154
with gil:
9255
# Load driver to check version
9356
try:
94-
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)
57+
nvcuda_handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)
9558
except Exception as e:
9659
raise NotSupportedError(f'CUDA driver is not found ({e})')
9760
global __cuDriverGetVersion
9861
if __cuDriverGetVersion == NULL:
99-
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(handle, 'cuDriverGetVersion')
62+
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(nvcuda_handle, 'cuDriverGetVersion')
10063
if __cuDriverGetVersion == NULL:
10164
raise RuntimeError('something went wrong')
10265
err = (<int (*)(int*) noexcept nogil>__cuDriverGetVersion)(&driver_ver)
10366
if err != 0:
10467
raise RuntimeError('something went wrong')
10568

10669
# Load library
107-
handle = load_library(driver_ver)
70+
handle = <intptr_t>load_library(driver_ver)
10871

10972
# Load function
11073
global __nvJitLinkCreate

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
#
55
# This code was automatically generated across versions from 11.0.3 to 12.9.0. Do not modify it directly.
66

7-
from libc.stdint cimport intptr_t
8-
9-
from .utils cimport get_nvvm_dso_version_suffix
7+
from libc.stdint cimport intptr_t, uintptr_t
108

119
from .utils import FunctionNotFoundError, NotSupportedError
1210

11+
from cuda.bindings import path_finder
12+
1313
###############################################################################
1414
# Extern
1515
###############################################################################
@@ -51,16 +51,8 @@ cdef void* __nvvmGetProgramLog = NULL
5151

5252

5353
cdef void* load_library(const int driver_ver) except* with gil:
54-
cdef void* handle
55-
for suffix in get_nvvm_dso_version_suffix(driver_ver):
56-
so_name = "libnvvm.so" + (f".{suffix}" if suffix else suffix)
57-
handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
58-
if handle != NULL:
59-
break
60-
else:
61-
err_msg = dlerror()
62-
raise RuntimeError(f'Failed to dlopen libnvvm ({err_msg.decode()})')
63-
return handle
54+
cdef uintptr_t handle = path_finder._load_nvidia_dynamic_library("nvvm").handle
55+
return <void*>handle
6456

6557

6658
cdef int _check_or_init_nvvm() except -1 nogil:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 8 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@
66

77
from libc.stdint cimport intptr_t
88

9-
from .utils cimport get_nvvm_dso_version_suffix
10-
119
from .utils import FunctionNotFoundError, NotSupportedError
1210

13-
import os
14-
import site
11+
from cuda.bindings import path_finder
1512

1613
import win32api
1714

@@ -40,52 +37,9 @@ cdef void* __nvvmGetProgramLogSize = NULL
4037
cdef void* __nvvmGetProgramLog = NULL
4138

4239

43-
cdef inline list get_site_packages():
44-
return [site.getusersitepackages()] + site.getsitepackages() + ["conda"]
45-
46-
47-
cdef load_library(const int driver_ver):
48-
handle = 0
49-
50-
for suffix in get_nvvm_dso_version_suffix(driver_ver):
51-
if len(suffix) == 0:
52-
continue
53-
dll_name = "nvvm64_40_0.dll"
54-
55-
# First check if the DLL has been loaded by 3rd parties
56-
try:
57-
return win32api.GetModuleHandle(dll_name)
58-
except:
59-
pass
60-
61-
# Next, check if DLLs are installed via pip or conda
62-
for sp in get_site_packages():
63-
if sp == "conda":
64-
# nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
65-
conda_prefix = os.environ.get("CONDA_PREFIX")
66-
if conda_prefix is None:
67-
continue
68-
mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin")
69-
else:
70-
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
71-
if os.path.isdir(mod_path):
72-
os.add_dll_directory(mod_path)
73-
try:
74-
return win32api.LoadLibraryEx(
75-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76-
os.path.join(mod_path, dll_name),
77-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78-
except:
79-
pass
80-
81-
# Finally, try default search
82-
# Only reached if DLL wasn't found in any site-package path
83-
try:
84-
return win32api.LoadLibrary(dll_name)
85-
except:
86-
pass
87-
88-
raise RuntimeError('Failed to load nvvm')
40+
cdef void* load_library(int driver_ver) except* with gil:
41+
cdef intptr_t handle = path_finder._load_nvidia_dynamic_library("nvvm").handle
42+
return <void*>handle
8943

9044

9145
cdef int _check_or_init_nvvm() except -1 nogil:
@@ -94,23 +48,24 @@ cdef int _check_or_init_nvvm() except -1 nogil:
9448
return 0
9549

9650
cdef int err, driver_ver
51+
cdef intptr_t handle
9752
with gil:
9853
# Load driver to check version
9954
try:
100-
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)
55+
nvcuda_handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)
10156
except Exception as e:
10257
raise NotSupportedError(f'CUDA driver is not found ({e})')
10358
global __cuDriverGetVersion
10459
if __cuDriverGetVersion == NULL:
105-
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(handle, 'cuDriverGetVersion')
60+
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(nvcuda_handle, 'cuDriverGetVersion')
10661
if __cuDriverGetVersion == NULL:
10762
raise RuntimeError('something went wrong')
10863
err = (<int (*)(int*) noexcept nogil>__cuDriverGetVersion)(&driver_ver)
10964
if err != 0:
11065
raise RuntimeError('something went wrong')
11166

11267
# Load library
113-
handle = load_library(driver_ver)
68+
handle = <intptr_t>load_library(driver_ver)
11469

11570
# Load function
11671
global __nvvmVersion

cuda_bindings/cuda/bindings/_internal/utils.pxd

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,3 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj,
165165

166166
cdef bint is_nested_sequence(data)
167167
cdef void* get_buffer_pointer(buf, Py_ssize_t size, readonly=*) except*
168-
169-
cdef tuple get_nvjitlink_dso_version_suffix(int driver_ver)
170-
cdef tuple get_nvvm_dso_version_suffix(int driver_ver)

cuda_bindings/cuda/bindings/_internal/utils.pyx

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,3 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj,
127127
class FunctionNotFoundError(RuntimeError): pass
128128

129129
class NotSupportedError(RuntimeError): pass
130-
131-
132-
cdef tuple get_nvjitlink_dso_version_suffix(int driver_ver):
133-
if 12000 <= driver_ver < 13000:
134-
return ('12', '')
135-
raise NotSupportedError(f'CUDA driver version {driver_ver} is not supported')
136-
137-
138-
cdef tuple get_nvvm_dso_version_suffix(int driver_ver):
139-
if 11000 <= driver_ver < 11020:
140-
return ('3', '')
141-
if 11020 <= driver_ver < 13000:
142-
return ('4', '')
143-
raise NotSupportedError(f'CUDA driver version {driver_ver} is not supported')

0 commit comments

Comments
 (0)