|
4 | 4 |
|
5 | 5 | from numba.core import types
|
6 | 6 |
|
| 7 | +from numba_dpex import dpctl_sem_version |
| 8 | + |
7 | 9 |
|
8 | 10 | def numba_type_to_dpctl_typenum(context, ty):
|
9 | 11 | """
|
10 | 12 | This function looks up the dpctl defined enum values from
|
11 | 13 | ``DPCTLKernelArgType``.
|
12 | 14 | """
|
13 | 15 |
|
14 |
| - val = None |
15 |
| - if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): |
16 |
| - # DPCTL_LONG_LONG |
17 |
| - val = context.get_constant(types.int32, 9) |
18 |
| - elif ty == types.uint32: |
19 |
| - # DPCTL_UNSIGNED_LONG_LONG |
20 |
| - val = context.get_constant(types.int32, 10) |
21 |
| - elif ty == types.boolean: |
22 |
| - # DPCTL_UNSIGNED_INT |
23 |
| - val = context.get_constant(types.int32, 5) |
24 |
| - elif ty == types.int64: |
25 |
| - # DPCTL_LONG_LONG |
26 |
| - val = context.get_constant(types.int32, 9) |
27 |
| - elif ty == types.uint64: |
28 |
| - # DPCTL_SIZE_T |
29 |
| - val = context.get_constant(types.int32, 11) |
30 |
| - elif ty == types.float32: |
31 |
| - # DPCTL_FLOAT |
32 |
| - val = context.get_constant(types.int32, 12) |
33 |
| - elif ty == types.float64: |
34 |
| - # DPCTL_DOUBLE |
35 |
| - val = context.get_constant(types.int32, 13) |
36 |
| - elif ty == types.voidptr or isinstance(ty, types.CPointer): |
37 |
| - # DPCTL_VOID_PTR |
38 |
| - val = context.get_constant(types.int32, 15) |
39 |
| - else: |
40 |
| - raise NotImplementedError |
| 16 | + if dpctl_sem_version >= (0, 17, 0): |
| 17 | + from dpctl.enum_types import kernel_arg_type as kargty |
41 | 18 |
|
42 |
| - return val |
| 19 | + if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): |
| 20 | + return context.get_constant(types.int32, kargty.dpctl_int32.value) |
| 21 | + elif ty == types.uint32: |
| 22 | + return context.get_constant(types.int32, kargty.dpctl_uint32.value) |
| 23 | + elif ty == types.boolean: |
| 24 | + return context.get_constant(types.int32, kargty.dpctl_uint32.value) |
| 25 | + elif ty == types.int64: |
| 26 | + return context.get_constant(types.int32, kargty.dpctl_int64.value) |
| 27 | + elif ty == types.uint64: |
| 28 | + return context.get_constant(types.int32, kargty.dpctl_uint64.value) |
| 29 | + elif ty == types.float32: |
| 30 | + return context.get_constant(types.int32, kargty.dpctl_float32.value) |
| 31 | + elif ty == types.float64: |
| 32 | + return context.get_constant(types.int32, kargty.dpctl_float64.value) |
| 33 | + elif ty == types.voidptr or isinstance(ty, types.CPointer): |
| 34 | + return context.get_constant( |
| 35 | + types.int32, kargty.dpctl_void_ptr.value |
| 36 | + ) |
| 37 | + else: |
| 38 | + raise NotImplementedError |
| 39 | + else: |
| 40 | + if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): |
| 41 | + # DPCTL_LONG_LONG |
| 42 | + return context.get_constant(types.int32, 9) |
| 43 | + elif ty == types.uint32: |
| 44 | + # DPCTL_UNSIGNED_LONG_LONG |
| 45 | + return context.get_constant(types.int32, 10) |
| 46 | + elif ty == types.boolean: |
| 47 | + # DPCTL_UNSIGNED_INT |
| 48 | + return context.get_constant(types.int32, 5) |
| 49 | + elif ty == types.int64: |
| 50 | + # DPCTL_LONG_LONG |
| 51 | + return context.get_constant(types.int32, 9) |
| 52 | + elif ty == types.uint64: |
| 53 | + # DPCTL_SIZE_T |
| 54 | + return context.get_constant(types.int32, 11) |
| 55 | + elif ty == types.float32: |
| 56 | + # DPCTL_FLOAT |
| 57 | + return context.get_constant(types.int32, 12) |
| 58 | + elif ty == types.float64: |
| 59 | + # DPCTL_DOUBLE |
| 60 | + return context.get_constant(types.int32, 13) |
| 61 | + elif ty == types.voidptr or isinstance(ty, types.CPointer): |
| 62 | + # DPCTL_VOID_PTR |
| 63 | + return context.get_constant(types.int32, 15) |
| 64 | + else: |
| 65 | + raise NotImplementedError |
0 commit comments