2
2
#
3
3
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4
4
5
+ import importlib .metadata
6
+
5
7
from cuda import cuda , cudart
6
8
from cuda .core .experimental ._utils import handle_return
7
9
8
10
9
11
_backend = {
10
- "new" : {
11
- "file" : cuda .cuLibraryLoadFromFile ,
12
- "data" : cuda .cuLibraryLoadData ,
13
- "kernel" : cuda .cuLibraryGetKernel ,
14
- },
15
12
"old" : {
16
13
"file" : cuda .cuModuleLoad ,
17
14
"data" : cuda .cuModuleLoadDataEx ,
18
15
"kernel" : cuda .cuModuleGetFunction ,
19
16
},
20
17
}
18
+ _kernel_ctypes = [cuda .CUfunction ]
19
+
20
+ # binding availability depends on cuda-python version
21
+ py_major_ver = int (importlib .metadata .version ("cuda-python" ).split ("." )[0 ])
22
+ if py_major_ver >= 12 :
23
+ _backend ["new" ] = {
24
+ "file" : cuda .cuLibraryLoadFromFile ,
25
+ "data" : cuda .cuLibraryLoadData ,
26
+ "kernel" : cuda .cuLibraryGetKernel ,
27
+ }
28
+ _kernel_ctypes .append (cuda .CUkernel )
29
+ _kernel_ctypes = tuple (_kernel_ctypes )
21
30
22
31
23
32
class Kernel :
@@ -29,7 +38,7 @@ def __init__(self):
29
38
30
39
@staticmethod
31
40
def _from_obj (obj , mod ):
32
- assert isinstance (obj , ( cuda . CUkernel , cuda . CUfunction ) )
41
+ assert isinstance (obj , _kernel_ctypes )
33
42
assert isinstance (mod , ObjectCode )
34
43
ker = Kernel .__new__ (Kernel )
35
44
ker ._handle = obj
@@ -49,7 +58,10 @@ def __init__(self, module, code_type, jit_options=None, *,
49
58
self ._handle = None
50
59
51
60
driver_ver = handle_return (cuda .cuDriverGetVersion ())
52
- self ._loader = _backend ["new" ] if driver_ver >= 12000 else _backend ["old" ]
61
+ if py_major_ver >= 12 and driver_ver >= 12000 :
62
+ self ._loader = _backend ["new" ]
63
+ else :
64
+ self ._loader = _backend ["old" ]
53
65
54
66
if isinstance (module , str ):
55
67
if driver_ver < 12000 and jit_options is not None :
@@ -65,7 +77,7 @@ def __init__(self, module, code_type, jit_options=None, *,
65
77
# TODO: support library options
66
78
[], [], 0 )
67
79
else :
68
- args = (module , len (jit_options ), jit_options .keys (), jit_options .values ())
80
+ args = (module , len (jit_options ), list ( jit_options .keys ()), list ( jit_options .values () ))
69
81
self ._handle = handle_return (self ._loader ["data" ](* args ))
70
82
71
83
self ._code_type = code_type
0 commit comments