Skip to content

Commit 12b3f84

Browse files
committed
Revert "Restore cuda/bindings/_bindings/cynvrtc.pyx.in as-is on main"
This reverts commit ba093f5.
1 parent 147b242 commit 12b3f84

File tree

1 file changed

+9
-54
lines changed

1 file changed

+9
-54
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 9 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
# This code was automatically generated with version 12.8.0. Do not modify it directly.
1010
{{if 'Windows' == platform.system()}}
1111
import os
12-
import site
13-
import struct
1412
import win32api
15-
from pywintypes import error
1613
{{else}}
1714
cimport cuda.bindings._lib.dlfcn as dlfcn
15+
from libc.stdint cimport uintptr_t
1816
{{endif}}
17+
from cuda.bindings import path_finder
1918

2019
cdef bint __cuPythonInit = False
2120
{{if 'nvrtcGetErrorString' in found_functions}}cdef void *__nvrtcGetErrorString = NULL{{endif}}
@@ -46,64 +45,18 @@ cdef bint __cuPythonInit = False
4645
{{if 'nvrtcSetFlowCallback' in found_functions}}cdef void *__nvrtcSetFlowCallback = NULL{{endif}}
4746

4847
cdef int cuPythonInit() except -1 nogil:
48+
{{if 'Windows' != platform.system()}}
49+
cdef void* handle = NULL
50+
{{endif}}
51+
4952
global __cuPythonInit
5053
if __cuPythonInit:
5154
return 0
5255
__cuPythonInit = True
5356

54-
# Load library
55-
{{if 'Windows' == platform.system()}}
56-
with gil:
57-
# First check if the DLL has been loaded by 3rd parties
58-
try:
59-
handle = win32api.GetModuleHandle("nvrtc64_120_0.dll")
60-
except:
61-
handle = None
62-
63-
# Else try default search
64-
if not handle:
65-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
66-
try:
67-
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
68-
except:
69-
pass
70-
71-
# Final check if DLLs can be found within pip installations
72-
if not handle:
73-
site_packages = [site.getusersitepackages()] + site.getsitepackages()
74-
for sp in site_packages:
75-
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
76-
if not os.path.isdir(mod_path):
77-
continue
78-
os.add_dll_directory(mod_path)
79-
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
80-
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
81-
try:
82-
handle = win32api.LoadLibraryEx(
83-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
84-
os.path.join(mod_path, "nvrtc64_120_0.dll"),
85-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
86-
87-
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
88-
# located in the same mod_path.
89-
# Update PATH environ so that the two dlls can find each other
90-
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
91-
except:
92-
pass
93-
94-
if not handle:
95-
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
96-
{{else}}
97-
handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)
98-
if handle == NULL:
99-
with gil:
100-
raise RuntimeError('Failed to dlopen libnvrtc.so.12')
101-
{{endif}}
102-
103-
104-
# Load function
10557
{{if 'Windows' == platform.system()}}
10658
with gil:
59+
handle = path_finder.load_nvidia_dynamic_library("nvrtc")
10760
{{if 'nvrtcGetErrorString' in found_functions}}
10861
try:
10962
global __nvrtcGetErrorString
@@ -288,6 +241,8 @@ cdef int cuPythonInit() except -1 nogil:
288241
{{endif}}
289242

290243
{{else}}
244+
with gil:
245+
handle = <void*><uintptr_t>path_finder.load_nvidia_dynamic_library("nvrtc")
291246
{{if 'nvrtcGetErrorString' in found_functions}}
292247
global __nvrtcGetErrorString
293248
__nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')

0 commit comments

Comments
 (0)