Skip to content

Use enum values for kernel arg types if dpctl >= 0.17 #1379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 50 additions & 28 deletions numba_dpex/dpctl_iface/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,61 @@

from numba.core import types

from numba_dpex import dpctl_sem_version


def numba_type_to_dpctl_typenum(context, ty):
"""
This function looks up the dpctl defined enum values from
``DPCTLKernelArgType``.
"""

val = None
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
# DPCTL_LONG_LONG
val = context.get_constant(types.int32, 9)
elif ty == types.uint32:
# DPCTL_UNSIGNED_LONG_LONG
val = context.get_constant(types.int32, 10)
elif ty == types.boolean:
# DPCTL_UNSIGNED_INT
val = context.get_constant(types.int32, 5)
elif ty == types.int64:
# DPCTL_LONG_LONG
val = context.get_constant(types.int32, 9)
elif ty == types.uint64:
# DPCTL_SIZE_T
val = context.get_constant(types.int32, 11)
elif ty == types.float32:
# DPCTL_FLOAT
val = context.get_constant(types.int32, 12)
elif ty == types.float64:
# DPCTL_DOUBLE
val = context.get_constant(types.int32, 13)
elif ty == types.voidptr or isinstance(ty, types.CPointer):
# DPCTL_VOID_PTR
val = context.get_constant(types.int32, 15)
else:
raise NotImplementedError
if dpctl_sem_version >= (0, 17, 0):
# FIXME change to imports from a dpctl enum/class rather than
# hard coding these numbers.

return val
if ty == types.boolean:
return context.get_constant(types.int32, 1)
elif ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
return context.get_constant(types.int32, 4)
elif ty == types.uint32:
return context.get_constant(types.int32, 5)
elif ty == types.int64:
return context.get_constant(types.int32, 6)
elif ty == types.uint64:
return context.get_constant(types.int32, 7)
elif ty == types.float32:
return context.get_constant(types.int32, 8)
elif ty == types.float64:
return context.get_constant(types.int32, 9)
elif ty == types.voidptr or isinstance(ty, types.CPointer):
return context.get_constant(types.int32, 10)
else:
raise NotImplementedError
else:
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
# DPCTL_LONG_LONG
return context.get_constant(types.int32, 9)
elif ty == types.uint32:
# DPCTL_UNSIGNED_LONG_LONG
return context.get_constant(types.int32, 10)
elif ty == types.boolean:
# DPCTL_UNSIGNED_INT
return context.get_constant(types.int32, 5)
elif ty == types.int64:
# DPCTL_LONG_LONG
return context.get_constant(types.int32, 9)
elif ty == types.uint64:
# DPCTL_SIZE_T
return context.get_constant(types.int32, 11)
elif ty == types.float32:
# DPCTL_FLOAT
return context.get_constant(types.int32, 12)
elif ty == types.float64:
# DPCTL_DOUBLE
return context.get_constant(types.int32, 13)
elif ty == types.voidptr or isinstance(ty, types.CPointer):
# DPCTL_VOID_PTR
return context.get_constant(types.int32, 15)
else:
raise NotImplementedError