Skip to content

Commit 400e4ea

Browse files
committed
Support wheels for Windows
1 parent a59ecd4 commit 400e4ea

File tree

4 files changed

+65
-9
lines changed

4 files changed

+65
-9
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
# this software and related documentation outside the terms of the EULA
77
# is strictly prohibited.
88
{{if 'Windows' == platform.system()}}
9-
import win32api
9+
import os
10+
import site
1011
import struct
12+
import win32api
1113
from pywintypes import error
1214
{{else}}
1315
cimport cuda.bindings._lib.dlfcn as dlfcn
@@ -44,11 +46,39 @@ cdef int cuPythonInit() except -1 nogil:
4446

4547
# Load library
4648
{{if 'Windows' == platform.system()}}
47-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
49+
handle = NULL
4850
with gil:
51+
# First check if the DLL has been loaded by 3rd parties
4952
try:
50-
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
53+
handle = win32api.GetModuleHandle("nvrtc64_120_0.dll")
5154
except:
55+
pass
56+
57+
# Try default search
58+
if handle == NULL:
59+
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
60+
try:
61+
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
62+
except:
63+
pass
64+
65+
# Check if DLLs are found within pip installations
66+
if handle == NULL:
67+
site_packages = [site.getusersitepackages()] + site.getsitepackages()
68+
for sp in site_packages:
69+
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
70+
if not os.path.isdir(mod_path):
71+
continue
72+
os.add_dll_directory(mod_path)
73+
try:
74+
handle = win32api.LoadLibraryEx(
75+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76+
os.path.join(mod_path, "nvrtc64_120_0.dll"),
77+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78+
except:
79+
pass
80+
81+
if handle == NULL:
5282
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
5383
{{else}}
5484
handle = NULL

cuda_bindings/docs/source/release/12.x.y-notes.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,22 @@
33
Released on MM DD, 20YY.
44

55
## Highlights
6-
- Added bindings for nvJitLink. It requires nvJitLink from CUDA 12.3 or above.
6+
- Add bindings for nvJitLink. It requires nvJitLink from CUDA 12.3 or above.
7+
- Add optional dependencies to wheels for NVRTC and nvJitLink
8+
- Enable discovery and loading of shared library dependencies from wheels
9+
10+
## Wheels support for optional dependencies
11+
12+
Optional dependencies are added for packages:
13+
14+
- nvidia-nvjitlink-cuXX
15+
- nvidia-cuda-nvrtc-cuXX
16+
17+
Installing these dependencies with cuda-python can be done using:
18+
```{code-block} shell
19+
pip install cuda-python[all]
20+
```
21+
22+
## Discovery and loading of shared library dependencies from wheels
23+
24+
Shared library search paths for wheel builds are now extended to check site-packages. This allows users to seamlessly use their wheel installation of the CUDA Toolkit with cuda-python.

cuda_bindings/pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ dependencies = [
3232
"pywin32; sys_platform == 'win32'",
3333
]
3434

35+
[project.optional-dependencies]
36+
all = [
37+
"nvidia-cuda-nvrtc-cu12",
38+
"nvidia-nvjitlink-cu12>=12.3"
39+
]
40+
3541
[project.urls]
3642
Repository = "https://github.com/NVIDIA/cuda-python"
3743
Documentation = "https://nvidia.github.io/cuda-python/"

cuda_bindings/setup.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -301,19 +301,21 @@ def initialize_options(self):
301301
self.parallel = nthreads
302302

303303
def build_extension(self, ext):
304-
if building_wheel:
304+
if building_wheel and sys.platform == "linux":
305305
# Strip binaries to remove debug symbols
306306
extra_linker_flags = ["-Wl,--strip-all"]
307307

308308
# Allow extensions to discover libraries at runtime
309309
# relative their wheels installation.
310-
ldflag = "-Wl,--disable-new-dtags"
311310
if ext.name == "cuda.bindings._bindings.cynvrtc":
312-
ldflag += f",-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
311+
ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
313312
elif ext.name == "cuda.bindings._internal.nvjitlink":
314-
ldflag += f",-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
313+
ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
314+
else:
315+
ldflag = None
315316

316-
extra_linker_flags.append(ldflag)
317+
if ldflag:
318+
extra_linker_flags.append(ldflag)
317319
else:
318320
extra_linker_flags = []
319321

0 commit comments

Comments
 (0)