Skip to content

Commit 9364e42

Browse files
committed
Eliminate unnecessary (and confusing) conversions between pywintypes.HANDLE, intptr_t, void*
Based on insights obtained while working on a576327
1 parent ee51a52 commit 9364e42

File tree

2 files changed

+6
-18
lines changed

2 files changed

+6
-18
lines changed

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,35 +39,29 @@ cdef void* __nvJitLinkGetInfoLog = NULL
3939
cdef void* __nvJitLinkVersion = NULL
4040

4141

42-
cdef void* load_library(int driver_ver) except* with gil:
43-
cdef intptr_t handle = path_finder._load_nvidia_dynamic_library("nvJitLink").handle
44-
return <void*>handle
45-
46-
4742
cdef int _check_or_init_nvjitlink() except -1 nogil:
4843
global __py_nvjitlink_init
4944
if __py_nvjitlink_init:
5045
return 0
5146

5247
cdef int err, driver_ver
53-
cdef intptr_t handle
5448
with gil:
5549
# Load driver to check version
5650
try:
57-
nvcuda_handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)
51+
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)
5852
except Exception as e:
5953
raise NotSupportedError(f'CUDA driver is not found ({e})')
6054
global __cuDriverGetVersion
6155
if __cuDriverGetVersion == NULL:
62-
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(nvcuda_handle, 'cuDriverGetVersion')
56+
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(handle, 'cuDriverGetVersion')
6357
if __cuDriverGetVersion == NULL:
6458
raise RuntimeError('something went wrong')
6559
err = (<int (*)(int*) noexcept nogil>__cuDriverGetVersion)(&driver_ver)
6660
if err != 0:
6761
raise RuntimeError('something went wrong')
6862

6963
# Load library
70-
handle = <intptr_t>load_library(driver_ver)
64+
handle = path_finder._load_nvidia_dynamic_library("nvJitLink").handle
7165

7266
# Load function
7367
global __nvJitLinkCreate

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,35 +37,29 @@ cdef void* __nvvmGetProgramLogSize = NULL
3737
cdef void* __nvvmGetProgramLog = NULL
3838

3939

40-
cdef void* load_library(int driver_ver) except* with gil:
41-
cdef intptr_t handle = path_finder._load_nvidia_dynamic_library("nvvm").handle
42-
return <void*>handle
43-
44-
4540
cdef int _check_or_init_nvvm() except -1 nogil:
4641
global __py_nvvm_init
4742
if __py_nvvm_init:
4843
return 0
4944

5045
cdef int err, driver_ver
51-
cdef intptr_t handle
5246
with gil:
5347
# Load driver to check version
5448
try:
55-
nvcuda_handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)
49+
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)
5650
except Exception as e:
5751
raise NotSupportedError(f'CUDA driver is not found ({e})')
5852
global __cuDriverGetVersion
5953
if __cuDriverGetVersion == NULL:
60-
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(nvcuda_handle, 'cuDriverGetVersion')
54+
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(handle, 'cuDriverGetVersion')
6155
if __cuDriverGetVersion == NULL:
6256
raise RuntimeError('something went wrong')
6357
err = (<int (*)(int*) noexcept nogil>__cuDriverGetVersion)(&driver_ver)
6458
if err != 0:
6559
raise RuntimeError('something went wrong')
6660

6761
# Load library
68-
handle = <intptr_t>load_library(driver_ver)
62+
handle = path_finder._load_nvidia_dynamic_library("nvvm").handle
6963

7064
# Load function
7165
global __nvvmVersion

0 commit comments

Comments
 (0)