Skip to content

Commit 732b654

Browse files
author
Diptorup Deb
committed
Use enum values for kernel arg types if dpctl >= 0.17
1 parent be68c49 commit 732b654

File tree

1 file changed

+51
-28
lines changed

1 file changed

+51
-28
lines changed

numba_dpex/dpctl_iface/_helpers.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,62 @@
44

55
from numba.core import types
66

7+
from numba_dpex import dpctl_sem_version
8+
79

810
def numba_type_to_dpctl_typenum(context, ty):
911
"""
1012
This function looks up the dpctl defined enum values from
1113
``DPCTLKernelArgType``.
1214
"""
1315

14-
val = None
15-
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
16-
# DPCTL_LONG_LONG
17-
val = context.get_constant(types.int32, 9)
18-
elif ty == types.uint32:
19-
# DPCTL_UNSIGNED_LONG_LONG
20-
val = context.get_constant(types.int32, 10)
21-
elif ty == types.boolean:
22-
# DPCTL_UNSIGNED_INT
23-
val = context.get_constant(types.int32, 5)
24-
elif ty == types.int64:
25-
# DPCTL_LONG_LONG
26-
val = context.get_constant(types.int32, 9)
27-
elif ty == types.uint64:
28-
# DPCTL_SIZE_T
29-
val = context.get_constant(types.int32, 11)
30-
elif ty == types.float32:
31-
# DPCTL_FLOAT
32-
val = context.get_constant(types.int32, 12)
33-
elif ty == types.float64:
34-
# DPCTL_DOUBLE
35-
val = context.get_constant(types.int32, 13)
36-
elif ty == types.voidptr or isinstance(ty, types.CPointer):
37-
# DPCTL_VOID_PTR
38-
val = context.get_constant(types.int32, 15)
39-
else:
40-
raise NotImplementedError
16+
if dpctl_sem_version >= (0, 17, 0):
17+
from dpctl.enum_types import kernel_arg_type as kargty
4118

42-
return val
19+
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
20+
return context.get_constant(types.int32, kargty.dpctl_int32.value)
21+
elif ty == types.uint32:
22+
return context.get_constant(types.int32, kargty.dpctl_uint32.value)
23+
elif ty == types.boolean:
24+
return context.get_constant(types.int32, kargty.dpctl_uint32.value)
25+
elif ty == types.int64:
26+
return context.get_constant(types.int32, kargty.dpctl_int64.value)
27+
elif ty == types.uint64:
28+
return context.get_constant(types.int32, kargty.dpctl_uint64.value)
29+
elif ty == types.float32:
30+
return context.get_constant(types.int32, kargty.dpctl_float32.value)
31+
elif ty == types.float64:
32+
return context.get_constant(types.int32, kargty.dpctl_float64.value)
33+
elif ty == types.voidptr or isinstance(ty, types.CPointer):
34+
return context.get_constant(
35+
types.int32, kargty.dpctl_void_ptr.value
36+
)
37+
else:
38+
raise NotImplementedError
39+
else:
40+
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
41+
# DPCTL_LONG_LONG
42+
return context.get_constant(types.int32, 9)
43+
elif ty == types.uint32:
44+
# DPCTL_UNSIGNED_LONG_LONG
45+
return context.get_constant(types.int32, 10)
46+
elif ty == types.boolean:
47+
# DPCTL_UNSIGNED_INT
48+
return context.get_constant(types.int32, 5)
49+
elif ty == types.int64:
50+
# DPCTL_LONG_LONG
51+
return context.get_constant(types.int32, 9)
52+
elif ty == types.uint64:
53+
# DPCTL_SIZE_T
54+
return context.get_constant(types.int32, 11)
55+
elif ty == types.float32:
56+
# DPCTL_FLOAT
57+
return context.get_constant(types.int32, 12)
58+
elif ty == types.float64:
59+
# DPCTL_DOUBLE
60+
return context.get_constant(types.int32, 13)
61+
elif ty == types.voidptr or isinstance(ty, types.CPointer):
62+
# DPCTL_VOID_PTR
63+
return context.get_constant(types.int32, 15)
64+
else:
65+
raise NotImplementedError

0 commit comments

Comments
 (0)