15
15
"kernel" : cuda .cuModuleGetFunction ,
16
16
},
17
17
}
18
- _kernel_ctypes = [cuda .CUfunction ]
19
18
20
19
# binding availability depends on cuda-python version
21
20
py_major_ver = int (importlib .metadata .version ("cuda-python" ).split ("." )[0 ])
25
24
"data" : cuda .cuLibraryLoadData ,
26
25
"kernel" : cuda .cuLibraryGetKernel ,
27
26
}
28
- _kernel_ctypes .append (cuda .CUkernel )
29
- _kernel_ctypes = tuple (_kernel_ctypes )
27
+ _kernel_ctypes = (cuda .CUfunction , cuda .CUkernel )
28
+ else :
29
+ _kernel_ctypes = (cuda .CUfunction ,)
30
+ driver_ver = handle_return (cuda .cuDriverGetVersion ())
30
31
31
32
32
33
class Kernel :
@@ -45,6 +46,8 @@ def _from_obj(obj, mod):
45
46
ker ._module = mod
46
47
return ker
47
48
49
+ # TODO: implement from_handle()
50
+
48
51
49
52
class ObjectCode :
50
53
@@ -57,11 +60,8 @@ def __init__(self, module, code_type, jit_options=None, *,
57
60
raise ValueError
58
61
self ._handle = None
59
62
60
- driver_ver = handle_return (cuda .cuDriverGetVersion ())
61
- if py_major_ver >= 12 and driver_ver >= 12000 :
62
- self ._loader = _backend ["new" ]
63
- else :
64
- self ._loader = _backend ["old" ]
63
+ backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000 ) else "old"
64
+ self ._loader = _backend [backend ]
65
65
66
66
if isinstance (module , str ):
67
67
if driver_ver < 12000 and jit_options is not None :
@@ -72,11 +72,11 @@ def __init__(self, module, code_type, jit_options=None, *,
72
72
assert isinstance (module , bytes )
73
73
if jit_options is None :
74
74
jit_options = {}
75
- if driver_ver >= 12000 :
75
+ if backend == "new" :
76
76
args = (module , list (jit_options .keys ()), list (jit_options .values ()), len (jit_options ),
77
77
# TODO: support library options
78
78
[], [], 0 )
79
- else :
79
+ else : # "old" backend
80
80
args = (module , len (jit_options ), list (jit_options .keys ()), list (jit_options .values ()))
81
81
self ._handle = handle_return (self ._loader ["data" ](* args ))
82
82
@@ -95,3 +95,5 @@ def get_kernel(self, name):
95
95
name = name .encode ()
96
96
data = handle_return (self ._loader ["kernel" ](self ._handle , name ))
97
97
return Kernel ._from_obj (data , self )
98
+
99
+ # TODO: implement from_handle()
0 commit comments