Skip to content

Commit cf35681

Browse files
committed
fix binding version check
1 parent b647dfc commit cf35681

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from cuda.core.experimental._kernel_arg_handler import ParamHolder
1212
from cuda.core.experimental._module import Kernel
1313
from cuda.core.experimental._stream import Stream
14-
from cuda.core.experimental._utils import CUDAError, check_or_create_options, handle_return
14+
from cuda.core.experimental._utils import CUDAError, check_or_create_options, get_binding_version, handle_return
1515

1616
# TODO: revisit this treatment for py313t builds
1717
_inited = False
@@ -25,7 +25,7 @@ def _lazy_init():
2525

2626
global _use_ex
2727
# binding availability depends on cuda-python version
28-
_py_major_minor = tuple(int(v) for v in (importlib.metadata.version("cuda-python").split(".")[:2]))
28+
_py_major_minor = get_binding_version()
2929
_driver_ver = handle_return(cuda.cuDriverGetVersion())
3030
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
3131
_inited = True

cuda_core/cuda/core/experimental/_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import importlib.metadata
66

77
from cuda import cuda
8-
from cuda.core.experimental._utils import handle_return, precondition
8+
from cuda.core.experimental._utils import get_binding_version, handle_return, precondition
99

1010
_backend = {
1111
"old": {
@@ -30,7 +30,7 @@ def _lazy_init():
3030

3131
global _py_major_ver, _driver_ver, _kernel_ctypes
3232
# binding availability depends on cuda-python version
33-
_py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
33+
_py_major_ver, _ = get_binding_version()
3434
if _py_major_ver >= 12:
3535
_backend["new"] = {
3636
"file": cuda.cuLibraryLoadFromFile,

cuda_core/cuda/core/experimental/_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,11 @@ def get_device_from_ctx(ctx_handle) -> int:
134134
assert ctx_handle == handle_return(cuda.cuCtxPopCurrent())
135135
handle_return(cuda.cuCtxPushCurrent(prev_ctx))
136136
return device_id
137+
138+
139+
def get_binding_version():
140+
try:
141+
major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2]
142+
except importlib.metadata.PackageNotFoundError:
143+
major_minor = importlib.metadata.version("cuda-python").split(".")[:2]
144+
return tuple(int(v) for v in major_minor)

0 commit comments

Comments
 (0)