@@ -14,25 +14,26 @@ def numba_type_to_dpctl_typenum(context, ty):
14
14
"""
15
15
16
16
if dpctl_sem_version >= (0 , 17 , 0 ):
17
- # FIXME change to imports from a dpctl enum/class rather than
18
- # hard coding these numbers.
17
+ from dpctl ._sycl_queue import kernel_arg_type as kargty
19
18
20
19
if ty == types .boolean :
21
- return context .get_constant (types .int32 , 1 )
20
+ return context .get_constant (types .int32 , kargty . dpctl_uint8 . value )
22
21
elif ty == types .int32 or isinstance (ty , types .scalars .IntegerLiteral ):
23
- return context .get_constant (types .int32 , 4 )
22
+ return context .get_constant (types .int32 , kargty . dpctl_int32 . value )
24
23
elif ty == types .uint32 :
25
- return context .get_constant (types .int32 , 5 )
24
+ return context .get_constant (types .int32 , kargty . dpctl_uint32 . value )
26
25
elif ty == types .int64 :
27
- return context .get_constant (types .int32 , 6 )
26
+ return context .get_constant (types .int32 , kargty . dpctl_int64 . value )
28
27
elif ty == types .uint64 :
29
- return context .get_constant (types .int32 , 7 )
28
+ return context .get_constant (types .int32 , kargty . dpctl_uint64 . value )
30
29
elif ty == types .float32 :
31
- return context .get_constant (types .int32 , 8 )
30
+ return context .get_constant (types .int32 , kargty . dpctl_float32 . value )
32
31
elif ty == types .float64 :
33
- return context .get_constant (types .int32 , 9 )
32
+ return context .get_constant (types .int32 , kargty . dpctl_float64 . value )
34
33
elif ty == types .voidptr or isinstance (ty , types .CPointer ):
35
- return context .get_constant (types .int32 , 10 )
34
+ return context .get_constant (
35
+ types .int32 , kargty .dpctl_void_ptr .value
36
+ )
36
37
else :
37
38
raise NotImplementedError
38
39
else :
0 commit comments