Skip to content

Commit c2c22a5

Browse files
committed
Add PrivateArray kernel_api
1 parent 4b1c8a9 commit c2c22a5

File tree

5 files changed

+239
-0
lines changed

5 files changed

+239
-0
lines changed

numba_dpex/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_atomic_ref_overloads,
1818
_group_barrier_overloads,
1919
_index_space_id_overloads,
20+
_private_array_overloads,
2021
)
2122
from .decorators import device_func, kernel
2223
from .launcher import call_kernel, call_kernel_async
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Implements the SPIR-V overloads for the kernel_api.PrivateArray class.
7+
"""
8+
9+
import operator
10+
from functools import reduce
11+
12+
from numba.core import cgutils, errors, types
13+
from numba.core.cpu import CPUContext
14+
from numba.extending import intrinsic, overload
15+
from numba.np.arrayobj import get_itemsize as get_itemsize_np
16+
17+
from numba_dpex.core.types import USMNdArray
18+
from numba_dpex.kernel_api import PrivateArray
19+
from numba_dpex.ocl.oclimpl import _get_target_data
20+
from numba_dpex.utils import address_space as AddressSpace
21+
22+
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
23+
24+
25+
def get_itemsize_spirv(context, array_type):
    """
    Return the item size for the given array or buffer type.

    The size is the ABI size that the target data obtained from ``context``
    reports for the LLVM data type of ``array_type.dtype``.

    Args:
        context: The target context used both to look up the target data and
            to map ``array_type.dtype`` to its LLVM data type.
        array_type: An array-like type exposing a ``dtype`` attribute.

    Returns:
        The value returned by ``get_abi_size`` for the element data type.
    """
    target_data = _get_target_data(context)
    return context.get_data_type(array_type.dtype).get_abi_size(target_data)
32+
33+
34+
def require_literal(literal_type):
    """
    Check that a type, or every member of a sized collection of types, is a
    literal type.

    Args:
        literal_type: Either a single Numba type, or an object that defines
            ``__len__`` (e.g. a tuple type) whose members are Numba types.

    Raises:
        errors.TypingError: If ``literal_type`` has no ``__len__`` and is not
            an instance of ``types.Literal``, or if it has ``__len__`` and at
            least one of its members is not an instance of ``types.Literal``.
    """
    # Scalar case: anything without a length is treated as a single type.
    if not hasattr(literal_type, "__len__"):
        if not isinstance(literal_type, types.Literal):
            raise errors.TypingError("requires literal type")
        return

    # Sized case: iterate the members directly, the same way the intrinsic's
    # codegen iterates the shape type.
    for member in literal_type:
        if not isinstance(member, types.Literal):
            raise errors.TypingError("requires literal type")
43+
44+
45+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_private_array_ctor(ty_context, ty_shape, ty_dtype):
    """
    Type and lower the construction of an array in the private address space.

    Typing:
        ``ty_shape`` must pass ``require_literal`` and ``ty_dtype`` must be a
        ``types.DType``. The result type is a C-layout ``USMNdArray`` with
        ``addrspace=AddressSpace.PRIVATE`` whose ``ndim`` is ``len(ty_shape)``
        when the shape type defines ``__len__`` and 1 otherwise.

    Lowering:
        The data buffer is created with ``cgutils.alloca_once``; its element
        count is the product of the literal values of the shape. The array
        struct is then populated with that buffer, the shape cast to
        ``intp``, the computed strides and the item size.

    Returns:
        A ``(signature, codegen)`` pair as expected from an intrinsic.

    Raises:
        errors.TypingError: If the shape is not made of literal types or if
            ``ty_dtype`` is not a ``types.DType``.
    """
    require_literal(ty_shape)

    if not isinstance(ty_dtype, types.DType):
        raise errors.TypingError("Second argument must be instance of DType")

    # A shape type without a length is a single extent, i.e. a 1-d array.
    ndim = len(ty_shape) if hasattr(ty_shape, "__len__") else 1

    ty_array = USMNdArray(
        dtype=ty_dtype.dtype,
        ndim=ndim,
        layout="C",
        addrspace=AddressSpace.PRIVATE,
    )

    sig = ty_array(ty_shape, ty_dtype)

    def codegen(context: CPUContext, builder, sig, args):
        shape_value = args[0]
        shape_type = sig.args[0]
        array_type = sig.return_type

        # Create array object
        array_struct = context.make_array(array_type)(context, builder)

        ll_itemsize = cgutils.intp_t(get_itemsize_spirv(context, array_type))

        # Normalise the shape to parallel sequences of types and values.
        if isinstance(shape_type, types.BaseTuple):
            extent_types = shape_type
            extent_values = cgutils.unpack_tuple(builder, shape_value)
        else:
            extent_types = (shape_type,)
            extent_values = (shape_value,)

        extents = [
            context.cast(builder, value, fromty, types.intp)
            for fromty, value in zip(extent_types, extent_values)
        ]

        # Strides are accumulated starting from the fastest-varying
        # dimension: the first one for "F" layout, the last one otherwise.
        fortran_order = array_type.layout == "F"
        walk_order = extents if fortran_order else extents[::-1]

        running = ll_itemsize
        strides = []
        for extent in walk_order:
            strides.append(running)
            running = builder.mul(running, extent)
        if not fortran_order:
            strides.reverse()

        # The total element count is known at compile time because every
        # extent is a literal.
        element_count = reduce(
            operator.mul, [extent.literal_value for extent in extent_types]
        )
        dataptr = cgutils.alloca_once(
            builder,
            context.get_data_type(array_type.dtype),
            size=element_count,
        )

        context.populate_array(
            array_struct,
            data=dataptr,
            shape=extents,
            strides=strides,
            itemsize=ll_itemsize,
        )

        return array_struct._getvalue()

    return sig, codegen
117+
118+
119+
@overload(
    PrivateArray,
    prefer_literal=True,
    target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_private_array_ctor(
    shape,
    dtype,
):
    """Overload of the constructor for the class
    :class:`numba_dpex.kernel_api.PrivateArray`.

    The returned implementation forwards both arguments to
    ``_intrinsic_private_array_ctor``, which performs the argument checks
    and generates the allocation. ``prefer_literal=True`` is set because the
    intrinsic requires the shape to be typed as literal value(s).

    Args:
        shape: Type of the requested shape: a single literal or a tuple of
            literals.
        dtype: Type of the requested element data type.

    Returns:
        The function implementing the constructor for the given types.

    Raises:
        errors.TypingError: Raised by the intrinsic when it is typed, if
            ``shape`` (or any of its members) is not a literal type.
        errors.TypingError: Raised by the intrinsic when it is typed, if
            ``dtype`` is not an instance of ``types.DType``.

    """

    def ol_private_array_ctor_impl(
        shape,
        dtype,
    ):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_private_array_ctor(shape, dtype)

    return ol_private_array_ctor_impl

numba_dpex/kernel_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .index_space_ids import Group, Item, NdItem
1616
from .launcher import call_kernel
1717
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
18+
from .private_array import PrivateArray
1819
from .ranges import NdRange, Range
1920

2021
__all__ = [
@@ -28,6 +29,7 @@
2829
"Group",
2930
"NdItem",
3031
"Item",
32+
"PrivateArray",
3133
"group_barrier",
3234
"call_kernel",
3335
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Implements the PrivateArray class, a placeholder for an array that is
6+
allocated inside a kernel function. The class can only be used inside a
7+
kernel: its constructor and item accessors raise when called from Python.
8+
"""
9+
10+
from dpctl.tensor import usm_ndarray
11+
12+
KernelUseOnlyError = NotImplementedError("Only for use inside kernel")
13+
14+
15+
class PrivateArray(usm_ndarray):
    """
    The ``PrivateArray`` class is a placeholder for an array that is
    allocated from within a :func:`numba_dpex.kernel` decorated function.

    The class is only usable inside a kernel. Calling any of the methods
    defined here from pure Python code raises ``KernelUseOnlyError``, a
    ``NotImplementedError`` instance.
    """

    def __init__(self, shape, dtype) -> None:
        """Placeholder constructor taking the shape and dtype of the array.

        Raises:
            NotImplementedError: Always, when called outside of a kernel.
        """

        raise KernelUseOnlyError

    def __getitem__(self, idx_obj):
        """Placeholder for reading the value at the position ``idx_obj``.

        Raises:
            NotImplementedError: Always, when called outside of a kernel.
        """

        raise KernelUseOnlyError

    def __setitem__(self, idx_obj, val):
        """Placeholder for assigning ``val`` to the position ``idx_obj``.

        Raises:
            NotImplementedError: Always, when called outside of a kernel.
        """

        raise KernelUseOnlyError
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import dpnp
2+
import numpy as np
3+
import pytest
4+
5+
import numba_dpex.experimental as dpex_exp
6+
from numba_dpex.kernel_api import Item, PrivateArray, Range
7+
8+
9+
@dpex_exp.kernel
def private_array_kernel(item: Item, a):
    """Sum the values 0..9 through a 1-d private array.

    Each work item fills a 10-element ``PrivateArray`` with its indices and
    accumulates the elements into ``a`` at the work item's linear id.
    """
    gid = item.get_linear_id()
    scratch = PrivateArray(10, a.dtype)

    for k in range(10):
        scratch[k] = k

    a[gid] = 0
    for k in range(10):
        a[gid] += scratch[k]
20+
21+
22+
@dpex_exp.kernel
def private_2d_array_kernel(item: Item, a):
    """Sum the values 0..9 through a 2-d private array.

    Each work item fills a ``(5, 2)`` ``PrivateArray`` so that every value
    in 0..9 is stored exactly once, then accumulates the elements into ``a``
    at the work item's linear id.
    """
    gid = item.get_linear_id()
    scratch = PrivateArray(shape=(5, 2), dtype=a.dtype)

    for k in range(10):
        scratch[k % 5, k // 5] = k

    a[gid] = 0
    for k in range(10):
        a[gid] += scratch[k % 5, k // 5]
33+
34+
35+
@pytest.mark.parametrize(
    "kernel", [private_array_kernel, private_2d_array_kernel]
)
def test_private_array(kernel):
    """Run a private-array kernel and check every work item's result."""
    result = dpnp.empty(10, dtype=dpnp.float32)
    dpex_exp.call_kernel(kernel, Range(result.size), result)

    # Every work item writes the sum of the arithmetic series 0..9.
    series_sum = (0 + 9) * 10 / 2
    expected = np.full(result.size, series_sum, dtype=np.float32)

    assert np.array_equal(expected, result.asnumpy())

0 commit comments

Comments
 (0)