|
20 | 20 | from Cython.Build import cythonize
|
21 | 21 | from pyclibrary import CParser
|
22 | 22 | from setuptools import find_packages, setup
|
| 23 | +from setuptools.command.bdist_wheel import bdist_wheel |
23 | 24 | from setuptools.command.build_ext import build_ext
|
24 | 25 | from setuptools.extension import Extension
|
25 | 26 |
|
26 | 27 | # ----------------------------------------------------------------------
|
27 | 28 | # Fetch configuration options
|
28 | 29 |
|
# Locate the CUDA toolkit root: CUDA_HOME takes precedence and CUDA_PATH is
# the fallback. `or` is used instead of a nested .get() default so that a
# CUDA_HOME which is set but empty still falls through to CUDA_PATH.
CUDA_HOME = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if not CUDA_HOME:
    raise RuntimeError("Environment variable CUDA_HOME or CUDA_PATH is not set")
|
34 | 33 |
|
@@ -283,24 +282,49 @@ def do_cythonize(extensions):
|
283 | 282 | extensions += prep_extensions(sources)
|
284 | 283 |
|
285 | 284 | # ---------------------------------------------------------------------
|
286 |
| -# Custom build_ext command |
287 |
| -# Files are build in two steps: |
288 |
| -# 1) Cythonized (in the do_cythonize() command) |
289 |
| -# 2) Compiled to .o files as part of build_ext |
290 |
| -# This class is solely for passing the value of nthreads to build_ext |
| 285 | +# Custom cmdclass extensions |
| 286 | + |
# Flipped to True as soon as a bdist_wheel command starts running, so that
# the build_ext step can tell a wheel build apart from any other build.
building_wheel = False


class WheelsBuildExtensions(bdist_wheel):
    """bdist_wheel command that records that a wheel build is in progress."""

    def run(self):
        """Set the module-level ``building_wheel`` flag, then run bdist_wheel."""
        global building_wheel
        building_wheel = True
        bdist_wheel.run(self)
291 | 295 |
|
292 | 296 |
|
class ParallelBuildExtensions(build_ext):
    """build_ext command with parallel compilation and wheel-only link flags.

    Passes the module-level ``nthreads`` value to build_ext and, when a wheel
    is being built on Linux, adds linker flags that strip the binaries and
    set a relative rpath for selected extensions.
    """

    # Extension name -> rpath, relative to the installed extension module,
    # under which the extension should look for its shared library at runtime.
    _WHEEL_RPATHS = {
        "cuda.bindings._bindings.cynvrtc": "$ORIGIN/../../../nvidia/cuda_nvrtc/lib",
        "cuda.bindings._internal.nvjitlink": "$ORIGIN/../../../nvidia/nvjitlink/lib",
    }

    def initialize_options(self):
        """Initialise the default options, then enable parallel builds if requested."""
        super().initialize_options()
        if nthreads > 0:
            self.parallel = nthreads

    def build_extension(self, ext):
        """Build ``ext``, first appending the wheel-only linker flags on Linux.

        ``ext.extra_link_args`` is extended in place because the base
        implementation reads the link arguments from the extension object.
        """
        # Imported locally so this block does not depend on a module-level
        # ``import sys`` being present in the file.
        import sys

        # The flags below are GNU ld / ELF specific ($ORIGIN, new-dtags), so
        # they are only applied on Linux; other linkers reject or ignore them.
        if building_wheel and sys.platform.startswith("linux"):
            # Strip binaries to remove debug symbols
            wheel_link_args = ["-Wl,--strip-all"]

            # Allow extensions to discover libraries at runtime relative to
            # their wheel's installation directory.
            ldflag = "-Wl,--disable-new-dtags"
            rpath = self._WHEEL_RPATHS.get(ext.name)
            if rpath is not None:
                ldflag += f",-rpath,{rpath}"
            wheel_link_args.append(ldflag)

            ext.extra_link_args += wheel_link_args

        super().build_extension(ext)
301 | 322 |
|
302 | 323 |
|
# Map setuptools command names to the customised command classes that
# replace the stock implementations.
cmdclass = dict(
    bdist_wheel=WheelsBuildExtensions,
    build_ext=ParallelBuildExtensions,
)
304 | 328 |
|
305 | 329 | # ----------------------------------------------------------------------
|
306 | 330 | # Setup
|
|
0 commit comments