Skip to content

Commit f47ac40

Browse files
committed
fix binding version check
1 parent b647dfc commit f47ac40

File tree

5 files changed

+17
-12
lines changed

5 files changed

+17
-12
lines changed

cuda_bindings/setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,9 @@ def build_extension(self, ext):
308308
# Allow extensions to discover libraries at runtime
309309
# relative their wheels installation.
310310
if ext.name == "cuda.bindings._bindings.cynvrtc":
311-
ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
311+
ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
312312
elif ext.name == "cuda.bindings._internal.nvjitlink":
313-
ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
313+
ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
314314
else:
315315
ldflag = None
316316

@@ -326,7 +326,7 @@ def build_extension(self, ext):
326326
cmdclass = {
327327
"bdist_wheel": WheelsBuildExtensions,
328328
"build_ext": ParallelBuildExtensions,
329-
}
329+
}
330330

331331
# ----------------------------------------------------------------------
332332
# Setup

cuda_bindings/tests/test_nvjitlink.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ def ptx_header(version, arch):
5555
def check_nvjitlink_usable():
5656
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
5757

58-
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
59-
return False
60-
return True
58+
return inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") != 0
6159

6260

6361
pytestmark = pytest.mark.skipif(

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5-
import importlib.metadata
65
from dataclasses import dataclass
76
from typing import Optional, Union
87

@@ -11,7 +10,7 @@
1110
from cuda.core.experimental._kernel_arg_handler import ParamHolder
1211
from cuda.core.experimental._module import Kernel
1312
from cuda.core.experimental._stream import Stream
14-
from cuda.core.experimental._utils import CUDAError, check_or_create_options, handle_return
13+
from cuda.core.experimental._utils import CUDAError, check_or_create_options, get_binding_version, handle_return
1514

1615
# TODO: revisit this treatment for py313t builds
1716
_inited = False
@@ -25,7 +24,7 @@ def _lazy_init():
2524

2625
global _use_ex
2726
# binding availability depends on cuda-python version
28-
_py_major_minor = tuple(int(v) for v in (importlib.metadata.version("cuda-python").split(".")[:2]))
27+
_py_major_minor = get_binding_version()
2928
_driver_ver = handle_return(cuda.cuDriverGetVersion())
3029
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
3130
_inited = True

cuda_core/cuda/core/experimental/_module.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5-
import importlib.metadata
65

76
from cuda import cuda
8-
from cuda.core.experimental._utils import handle_return, precondition
7+
from cuda.core.experimental._utils import get_binding_version, handle_return, precondition
98

109
_backend = {
1110
"old": {
@@ -30,7 +29,7 @@ def _lazy_init():
3029

3130
global _py_major_ver, _driver_ver, _kernel_ctypes
3231
# binding availability depends on cuda-python version
33-
_py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
32+
_py_major_ver, _ = get_binding_version()
3433
if _py_major_ver >= 12:
3534
_backend["new"] = {
3635
"file": cuda.cuLibraryLoadFromFile,

cuda_core/cuda/core/experimental/_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

55
import functools
6+
import importlib.metadata
67
from collections import namedtuple
78
from typing import Callable, Dict
89

@@ -134,3 +135,11 @@ def get_device_from_ctx(ctx_handle) -> int:
134135
assert ctx_handle == handle_return(cuda.cuCtxPopCurrent())
135136
handle_return(cuda.cuCtxPushCurrent(prev_ctx))
136137
return device_id
138+
139+
140+
def get_binding_version():
141+
try:
142+
major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2]
143+
except importlib.metadata.PackageNotFoundError:
144+
major_minor = importlib.metadata.version("cuda-python").split(".")[:2]
145+
return tuple(int(v) for v in major_minor)

0 commit comments

Comments
 (0)