Skip to content

Commit b64f337

Browse files
committed
simplify & fix module/library handling
1 parent b319731 commit b64f337

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"kernel": cuda.cuModuleGetFunction,
1616
},
1717
}
18-
_kernel_ctypes = [cuda.CUfunction]
1918

2019
# binding availability depends on cuda-python version
2120
py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
@@ -25,8 +24,10 @@
2524
"data": cuda.cuLibraryLoadData,
2625
"kernel": cuda.cuLibraryGetKernel,
2726
}
28-
_kernel_ctypes.append(cuda.CUkernel)
29-
_kernel_ctypes = tuple(_kernel_ctypes)
27+
_kernel_ctypes = (cuda.CUfunction, cuda.CUkernel)
28+
else:
29+
_kernel_ctypes = (cuda.CUfunction,)
30+
driver_ver = handle_return(cuda.cuDriverGetVersion())
3031

3132

3233
class Kernel:
@@ -45,6 +46,8 @@ def _from_obj(obj, mod):
4546
ker._module = mod
4647
return ker
4748

49+
# TODO: implement from_handle()
50+
4851

4952
class ObjectCode:
5053

@@ -57,11 +60,8 @@ def __init__(self, module, code_type, jit_options=None, *,
5760
raise ValueError
5861
self._handle = None
5962

60-
driver_ver = handle_return(cuda.cuDriverGetVersion())
61-
if py_major_ver >= 12 and driver_ver >= 12000:
62-
self._loader = _backend["new"]
63-
else:
64-
self._loader = _backend["old"]
63+
backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old"
64+
self._loader = _backend[backend]
6565

6666
if isinstance(module, str):
6767
if driver_ver < 12000 and jit_options is not None:
@@ -72,11 +72,11 @@ def __init__(self, module, code_type, jit_options=None, *,
7272
assert isinstance(module, bytes)
7373
if jit_options is None:
7474
jit_options = {}
75-
if driver_ver >= 12000:
75+
if backend == "new":
7676
args = (module, list(jit_options.keys()), list(jit_options.values()), len(jit_options),
7777
# TODO: support library options
7878
[], [], 0)
79-
else:
79+
else: # "old" backend
8080
args = (module, len(jit_options), list(jit_options.keys()), list(jit_options.values()))
8181
self._handle = handle_return(self._loader["data"](*args))
8282

@@ -95,3 +95,5 @@ def get_kernel(self, name):
9595
name = name.encode()
9696
data = handle_return(self._loader["kernel"](self._handle, name))
9797
return Kernel._from_obj(data, self)
98+
99+
# TODO: implement from_handle()

0 commit comments

Comments
 (0)