Skip to content

Commit 66b91d1

Browse files
author
Diptorup Deb
committed
Add numba typing infrastructure for LocalAccessor
1 parent 78470ab commit 66b91d1

File tree

4 files changed

+90
-2
lines changed

4 files changed

+90
-2
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from numba.core.pythonapi import unbox
6+
from numba.core.types import Array, Type
7+
from numba.np import numpy_support
8+
9+
from numba_dpex.core.types import USMNdArray
10+
from numba_dpex.utils import address_space as AddressSpace
11+
12+
13+
class LocalAccessorType(USMNdArray):
14+
"""numba-dpex internal type to represent a Python object of
15+
:class:`numba_dpex.experimental.kernel_iface.LocalAccessor`.
16+
"""
17+
18+
def __init__(self, ndim, dtype):
19+
try:
20+
if isinstance(dtype, Type):
21+
parsed_dtype = dtype
22+
else:
23+
parsed_dtype = numpy_support.from_dtype(dtype)
24+
except NotImplementedError as exc:
25+
raise ValueError(f"Unsupported array dtype: {dtype}") from exc
26+
27+
type_name = (
28+
f"LocalAccessor(dtype={parsed_dtype}, ndim={ndim}, "
29+
f"address_space={AddressSpace.LOCAL})"
30+
)
31+
32+
super().__init__(
33+
ndim=ndim,
34+
layout="C",
35+
dtype=parsed_dtype,
36+
addrspace=AddressSpace.LOCAL,
37+
name=type_name,
38+
)
39+
40+
def cast_python_value(self, args):
41+
"""The helper function is not overloaded and using it on the
42+
LocalAccessorType throws a NotImplementedError.
43+
"""
44+
raise NotImplementedError
45+
46+
47+
@unbox(LocalAccessorType)
48+
def unbox_local_accessor(typ, obj, c): # pylint: disable=unused-argument
49+
"""Unboxes a Python LocalAccessor PyObject* into a numba-dpex internal
50+
representation.
51+
52+
A LocalAccessor object is represented internally in numba-dpex with the
53+
same data model as a numpy.ndarray. It is done as a LocalAccessor object
54+
serves only as a placeholder type when passed to ``call_kernel`` and the
55+
data buffer should never be accessed inside a host-side compiled function
56+
such as ``call_kernel``.
57+
58+
When a LocalAccessor object is passed as an argument to a kernel function
59+
it uses the USMArrayDeviceModel. Doing so allows numba-dpex to correctly
60+
generate the kernel signature passing in a pointer in the local address
61+
space.
62+
"""
63+
64+
nparrobj = c.pyapi.object_getattr_string(obj, "data")
65+
nparrtype = Array(typ.dtype, typ.ndim, typ.layout, readonly=False)
66+
return c.unbox(nparrtype, nparrobj)

numba_dpex/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from numba_dpex.core.boxing import *
1313
from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
1414

15+
from . import typeof
1516
from ._kernel_dpcpp_spirv_overloads import (
1617
_atomic_fence_overloads,
1718
_atomic_ref_overloads,

numba_dpex/experimental/models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from numba.core import types
1010
from numba.core.datamodel import DataModelManager, models
11-
from numba.core.datamodel.models import StructModel
11+
from numba.core.datamodel.models import ArrayModel, StructModel
1212
from numba.core.extending import register_model
1313

1414
import numba_dpex.core.datamodel.models as dpex_core_models
@@ -19,6 +19,7 @@
1919
)
2020

2121
from ..core.types.kernel_api.atomic_ref import AtomicRefType
22+
from ..core.types.kernel_api.local_accessor import LocalAccessorType
2223
from .types import KernelDispatcherType
2324

2425

@@ -60,6 +61,8 @@ def _init_exp_data_model_manager() -> DataModelManager:
6061
# Register the types and data model in the DpexExpTargetContext
6162
dmm.register(AtomicRefType, AtomicRefModel)
6263

64+
dmm.register(LocalAccessorType, dpex_core_models.USMArrayDeviceModel)
65+
6366
# Register the GroupType type
6467
dmm.register(GroupType, EmptyStructModel)
6568

@@ -85,3 +88,7 @@ def _init_exp_data_model_manager() -> DataModelManager:
8588

8689
# Register the NdItemType type
8790
register_model(NdItemType)(EmptyStructModel)
91+
92+
# The LocalAccessorType is registered with the EmptyStructModel in the default
93+
# data manager so that its attributes are not accessible inside dpjit.
94+
register_model(LocalAccessorType)(ArrayModel)

numba_dpex/experimental/typeof.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
ItemType,
1515
NdItemType,
1616
)
17-
from numba_dpex.kernel_api import AtomicRef, Group, Item, NdItem
17+
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
18+
from numba_dpex.kernel_api import AtomicRef, Group, Item, LocalAccessor, NdItem
1819

1920
from ..core.types.kernel_api.atomic_ref import AtomicRefType
2021

@@ -84,3 +85,16 @@ def typeof_nditem(val: NdItem, c):
8485
instance.
8586
"""
8687
return NdItemType(val.dimensions)
88+
89+
90+
@typeof_impl.register(LocalAccessor)
91+
def typeof_local_accessor(val: LocalAccessor, c) -> LocalAccessorType:
92+
"""Returns a ``numba_dpex.experimental.dpctpp_types.LocalAccessorType``
93+
instance for a Python LocalAccessor object.
94+
Args:
95+
val (LocalAccessor): Instance of the LocalAccessor type.
96+
c : Numba typing context used for type inference.
97+
Returns: LocalAccessorType object corresponding to the LocalAccessor object.
98+
"""
99+
100+
return LocalAccessorType(ndim=val.data.ndim, dtype=val.data.dtype)

0 commit comments

Comments
 (0)