Skip to content

Commit b319731

Browse files
committed
fix module load for cuda-python 11.x
1 parent 397b7c7 commit b319731

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,31 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
import importlib.metadata
6+
57
from cuda import cuda, cudart
68
from cuda.core.experimental._utils import handle_return
79

810

911
_backend = {
10-
"new": {
11-
"file": cuda.cuLibraryLoadFromFile,
12-
"data": cuda.cuLibraryLoadData,
13-
"kernel": cuda.cuLibraryGetKernel,
14-
},
1512
"old": {
1613
"file": cuda.cuModuleLoad,
1714
"data": cuda.cuModuleLoadDataEx,
1815
"kernel": cuda.cuModuleGetFunction,
1916
},
2017
}
18+
_kernel_ctypes = [cuda.CUfunction]
19+
20+
# binding availability depends on cuda-python version
21+
py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
22+
if py_major_ver >= 12:
23+
_backend["new"] = {
24+
"file": cuda.cuLibraryLoadFromFile,
25+
"data": cuda.cuLibraryLoadData,
26+
"kernel": cuda.cuLibraryGetKernel,
27+
}
28+
_kernel_ctypes.append(cuda.CUkernel)
29+
_kernel_ctypes = tuple(_kernel_ctypes)
2130

2231

2332
class Kernel:
@@ -29,7 +38,7 @@ def __init__(self):
2938

3039
@staticmethod
3140
def _from_obj(obj, mod):
32-
assert isinstance(obj, (cuda.CUkernel, cuda.CUfunction))
41+
assert isinstance(obj, _kernel_ctypes)
3342
assert isinstance(mod, ObjectCode)
3443
ker = Kernel.__new__(Kernel)
3544
ker._handle = obj
@@ -49,7 +58,10 @@ def __init__(self, module, code_type, jit_options=None, *,
4958
self._handle = None
5059

5160
driver_ver = handle_return(cuda.cuDriverGetVersion())
52-
self._loader = _backend["new"] if driver_ver >= 12000 else _backend["old"]
61+
if py_major_ver >= 12 and driver_ver >= 12000:
62+
self._loader = _backend["new"]
63+
else:
64+
self._loader = _backend["old"]
5365

5466
if isinstance(module, str):
5567
if driver_ver < 12000 and jit_options is not None:
@@ -65,7 +77,7 @@ def __init__(self, module, code_type, jit_options=None, *,
6577
# TODO: support library options
6678
[], [], 0)
6779
else:
68-
args = (module, len(jit_options), jit_options.keys(), jit_options.values())
80+
args = (module, len(jit_options), list(jit_options.keys()), list(jit_options.values()))
6981
self._handle = handle_return(self._loader["data"](*args))
7082

7183
self._code_type = code_type

0 commit comments

Comments
 (0)