Skip to content

Commit d4ab83f

Browse files
committed
Support discovery of nvrtc and nvjitlink libraries at run time
CTK installations distribute their libraries using personal packages: - nvidia-nvjitlink-cuXX - nvidia-cuda-nvrtc-cuXX The relative path of their libraries to cuda-bindings is consistent, and allows us to use relative paths to discover them when loading at run time.
1 parent 8620a28 commit d4ab83f

File tree

1 file changed

+36
-12
lines changed

1 file changed

+36
-12
lines changed

cuda_bindings/setup.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@
2020
from Cython.Build import cythonize
2121
from pyclibrary import CParser
2222
from setuptools import find_packages, setup
23+
from setuptools.command.bdist_wheel import bdist_wheel
2324
from setuptools.command.build_ext import build_ext
2425
from setuptools.extension import Extension
2526

2627
# ----------------------------------------------------------------------
2728
# Fetch configuration options
2829

29-
CUDA_HOME = os.environ.get("CUDA_HOME")
30-
if not CUDA_HOME:
31-
CUDA_HOME = os.environ.get("CUDA_PATH")
30+
CUDA_HOME = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", None))
3231
if not CUDA_HOME:
3332
raise RuntimeError("Environment variable CUDA_HOME or CUDA_PATH is not set")
3433

@@ -283,24 +282,49 @@ def do_cythonize(extensions):
283282
extensions += prep_extensions(sources)
284283

285284
# ---------------------------------------------------------------------
286-
# Custom build_ext command
287-
# Files are build in two steps:
288-
# 1) Cythonized (in the do_cythonize() command)
289-
# 2) Compiled to .o files as part of build_ext
290-
# This class is solely for passing the value of nthreads to build_ext
285+
# Custom cmdclass extensions
286+
287+
building_wheel = False
288+
289+
290+
class WheelsBuildExtensions(bdist_wheel):
291+
def run(self):
292+
global building_wheel
293+
building_wheel = True
294+
super().run()
291295

292296

293297
class ParallelBuildExtensions(build_ext):
294298
def initialize_options(self):
295-
build_ext.initialize_options(self)
299+
super().initialize_options()
296300
if nthreads > 0:
297301
self.parallel = nthreads
298302

299-
def finalize_options(self):
300-
build_ext.finalize_options(self)
303+
def build_extension(self, ext):
304+
if building_wheel:
305+
# Strip binaries to remove debug symbols
306+
extra_linker_flags = ["-Wl,--strip-all"]
307+
308+
# Allow extensions to discover libraries at runtime
309+
# relative their wheels installation.
310+
ldflag = "-Wl,--disable-new-dtags"
311+
if ext.name == "cuda.bindings._bindings.cynvrtc":
312+
ldflag += f",-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
313+
elif ext.name == "cuda.bindings._internal.nvjitlink":
314+
ldflag += f",-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
315+
316+
extra_linker_flags.append(ldflag)
317+
else:
318+
extra_linker_flags = []
319+
320+
ext.extra_link_args += extra_linker_flags
321+
super().build_extension(ext)
301322

302323

303-
cmdclass = {"build_ext": ParallelBuildExtensions}
324+
cmdclass = {
325+
"bdist_wheel": WheelsBuildExtensions,
326+
"build_ext": ParallelBuildExtensions,
327+
}
304328

305329
# ----------------------------------------------------------------------
306330
# Setup

0 commit comments

Comments
 (0)