|
4 | 4 |
|
5 | 5 | from numba.core import types
|
6 | 6 |
|
| 7 | +from numba_dpex import dpctl_sem_version |
| 8 | + |
7 | 9 |
|
8 | 10 | def numba_type_to_dpctl_typenum(context, ty):
|
9 | 11 | """
|
10 | 12 | This function looks up the dpctl defined enum values from
|
11 | 13 | ``DPCTLKernelArgType``.
|
12 | 14 | """
|
13 | 15 |
|
14 |
| - val = None |
15 |
| - if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): |
16 |
| - # DPCTL_LONG_LONG |
17 |
| - val = context.get_constant(types.int32, 9) |
18 |
| - elif ty == types.uint32: |
19 |
| - # DPCTL_UNSIGNED_LONG_LONG |
20 |
| - val = context.get_constant(types.int32, 10) |
21 |
| - elif ty == types.boolean: |
22 |
| - # DPCTL_UNSIGNED_INT |
23 |
| - val = context.get_constant(types.int32, 5) |
24 |
| - elif ty == types.int64: |
25 |
| - # DPCTL_LONG_LONG |
26 |
| - val = context.get_constant(types.int32, 9) |
27 |
| - elif ty == types.uint64: |
28 |
| - # DPCTL_SIZE_T |
29 |
| - val = context.get_constant(types.int32, 11) |
30 |
| - elif ty == types.float32: |
31 |
| - # DPCTL_FLOAT |
32 |
| - val = context.get_constant(types.int32, 12) |
33 |
| - elif ty == types.float64: |
34 |
| - # DPCTL_DOUBLE |
35 |
| - val = context.get_constant(types.int32, 13) |
36 |
| - elif ty == types.voidptr or isinstance(ty, types.CPointer): |
37 |
| - # DPCTL_VOID_PTR |
38 |
| - val = context.get_constant(types.int32, 15) |
39 |
| - else: |
40 |
| - raise NotImplementedError |
| 16 | + if dpctl_sem_version >= (0, 17, 0): |
| 17 | + # FIXME change to imports from a dpctl enum/class rather than |
| 18 | + # hard coding these numbers. |
41 | 19 |
|
42 |
| - return val |
| 20 | + if ty == types.boolean: |
| 21 | + return context.get_constant(types.int32, 1) |
| 22 | + elif ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): |
| 23 | + return context.get_constant(types.int32, 4) |
| 24 | + elif ty == types.uint32: |
| 25 | + return context.get_constant(types.int32, 5) |
| 26 | + elif ty == types.int64: |
| 27 | + return context.get_constant(types.int32, 6) |
| 28 | + elif ty == types.uint64: |
| 29 | + return context.get_constant(types.int32, 7) |
| 30 | + elif ty == types.float32: |
| 31 | + return context.get_constant(types.int32, 8) |
| 32 | + elif ty == types.float64: |
| 33 | + return context.get_constant(types.int32, 9) |
| 34 | + elif ty == types.voidptr or isinstance(ty, types.CPointer): |
| 35 | + return context.get_constant(types.int32, 10) |
| 36 | + else: |
| 37 | + raise NotImplementedError |
| 38 | + else: |
| 39 | + if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): |
| 40 | + # DPCTL_LONG_LONG |
| 41 | + return context.get_constant(types.int32, 9) |
| 42 | + elif ty == types.uint32: |
| 43 | + # DPCTL_UNSIGNED_LONG_LONG |
| 44 | + return context.get_constant(types.int32, 10) |
| 45 | + elif ty == types.boolean: |
| 46 | + # DPCTL_UNSIGNED_INT |
| 47 | + return context.get_constant(types.int32, 5) |
| 48 | + elif ty == types.int64: |
| 49 | + # DPCTL_LONG_LONG |
| 50 | + return context.get_constant(types.int32, 9) |
| 51 | + elif ty == types.uint64: |
| 52 | + # DPCTL_SIZE_T |
| 53 | + return context.get_constant(types.int32, 11) |
| 54 | + elif ty == types.float32: |
| 55 | + # DPCTL_FLOAT |
| 56 | + return context.get_constant(types.int32, 12) |
| 57 | + elif ty == types.float64: |
| 58 | + # DPCTL_DOUBLE |
| 59 | + return context.get_constant(types.int32, 13) |
| 60 | + elif ty == types.voidptr or isinstance(ty, types.CPointer): |
| 61 | + # DPCTL_VOID_PTR |
| 62 | + return context.get_constant(types.int32, 15) |
| 63 | + else: |
| 64 | + raise NotImplementedError |
0 commit comments