Skip to content

Commit eeb3145

Browse files
author
Diptorup Deb
committed
Update the interger values for kernel arg type if dpctl >= 0.17
1 parent be68c49 commit eeb3145

File tree

1 file changed

+50
-28
lines changed

1 file changed

+50
-28
lines changed

numba_dpex/dpctl_iface/_helpers.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,61 @@
44

55
from numba.core import types
66

7+
from numba_dpex import dpctl_sem_version
8+
79

810
def numba_type_to_dpctl_typenum(context, ty):
911
"""
1012
This function looks up the dpctl defined enum values from
1113
``DPCTLKernelArgType``.
1214
"""
1315

14-
val = None
15-
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
16-
# DPCTL_LONG_LONG
17-
val = context.get_constant(types.int32, 9)
18-
elif ty == types.uint32:
19-
# DPCTL_UNSIGNED_LONG_LONG
20-
val = context.get_constant(types.int32, 10)
21-
elif ty == types.boolean:
22-
# DPCTL_UNSIGNED_INT
23-
val = context.get_constant(types.int32, 5)
24-
elif ty == types.int64:
25-
# DPCTL_LONG_LONG
26-
val = context.get_constant(types.int32, 9)
27-
elif ty == types.uint64:
28-
# DPCTL_SIZE_T
29-
val = context.get_constant(types.int32, 11)
30-
elif ty == types.float32:
31-
# DPCTL_FLOAT
32-
val = context.get_constant(types.int32, 12)
33-
elif ty == types.float64:
34-
# DPCTL_DOUBLE
35-
val = context.get_constant(types.int32, 13)
36-
elif ty == types.voidptr or isinstance(ty, types.CPointer):
37-
# DPCTL_VOID_PTR
38-
val = context.get_constant(types.int32, 15)
39-
else:
40-
raise NotImplementedError
16+
if dpctl_sem_version >= (0, 17, 0):
17+
# FIXME change to imports from a dpctl enum/class rather than
18+
# hard coding these numbers.
4119

42-
return val
20+
if ty == types.boolean:
21+
return context.get_constant(types.int32, 1)
22+
elif ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
23+
return context.get_constant(types.int32, 4)
24+
elif ty == types.uint32:
25+
return context.get_constant(types.int32, 5)
26+
elif ty == types.int64:
27+
return context.get_constant(types.int32, 6)
28+
elif ty == types.uint64:
29+
return context.get_constant(types.int32, 7)
30+
elif ty == types.float32:
31+
return context.get_constant(types.int32, 8)
32+
elif ty == types.float64:
33+
return context.get_constant(types.int32, 9)
34+
elif ty == types.voidptr or isinstance(ty, types.CPointer):
35+
return context.get_constant(types.int32, 10)
36+
else:
37+
raise NotImplementedError
38+
else:
39+
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
40+
# DPCTL_LONG_LONG
41+
return context.get_constant(types.int32, 9)
42+
elif ty == types.uint32:
43+
# DPCTL_UNSIGNED_LONG_LONG
44+
return context.get_constant(types.int32, 10)
45+
elif ty == types.boolean:
46+
# DPCTL_UNSIGNED_INT
47+
return context.get_constant(types.int32, 5)
48+
elif ty == types.int64:
49+
# DPCTL_LONG_LONG
50+
return context.get_constant(types.int32, 9)
51+
elif ty == types.uint64:
52+
# DPCTL_SIZE_T
53+
return context.get_constant(types.int32, 11)
54+
elif ty == types.float32:
55+
# DPCTL_FLOAT
56+
return context.get_constant(types.int32, 12)
57+
elif ty == types.float64:
58+
# DPCTL_DOUBLE
59+
return context.get_constant(types.int32, 13)
60+
elif ty == types.voidptr or isinstance(ty, types.CPointer):
61+
# DPCTL_VOID_PTR
62+
return context.get_constant(types.int32, 15)
63+
else:
64+
raise NotImplementedError

0 commit comments

Comments
 (0)