32
32
33
33
class Kernel :
34
34
35
- __slots__ = ("_handle" , "_module" ,)
35
+ __slots__ = ("_handle" , "_module" , "_backend" )
36
36
37
37
def __init__ (self ):
38
38
raise NotImplementedError ("directly constructing a Kernel instance is not supported" )
39
39
40
40
@staticmethod
41
- def _from_obj (obj , mod ):
41
+ def _from_obj (obj , mod , backend ):
42
42
assert isinstance (obj , _kernel_ctypes )
43
43
assert isinstance (mod , ObjectCode )
44
44
ker = Kernel .__new__ (Kernel )
45
45
ker ._handle = obj
46
46
ker ._module = mod
47
+ ker ._backend = backend
47
48
return ker
48
49
49
50
# TODO: implement from_handle()
50
51
51
52
52
53
class ObjectCode :
53
54
54
- __slots__ = ("_handle" , "_code_type" , "_module" , "_loader" , "_sym_map" )
55
+ __slots__ = ("_handle" , "_code_type" , "_module" , "_loader" , "_loader_backend" , " _sym_map" )
55
56
_supported_code_type = ("cubin" , "ptx" , "fatbin" )
56
57
57
58
def __init__ (self , module , code_type , jit_options = None , * ,
@@ -62,6 +63,7 @@ def __init__(self, module, code_type, jit_options=None, *,
62
63
63
64
backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000 ) else "old"
64
65
self ._loader = _backend [backend ]
66
+ self ._loader_backend = backend
65
67
66
68
if isinstance (module , str ):
67
69
if driver_ver < 12000 and jit_options is not None :
@@ -94,6 +96,6 @@ def get_kernel(self, name):
94
96
except KeyError :
95
97
name = name .encode ()
96
98
data = handle_return (self ._loader ["kernel" ](self ._handle , name ))
97
- return Kernel ._from_obj (data , self )
99
+ return Kernel ._from_obj (data , self , self . _loader_backend )
98
100
99
101
# TODO: implement from_handle()
0 commit comments