Skip to content

Commit ade28e8

Browse files
committed
Add PrivateArray kernel_api
1 parent 4b1c8a9 commit ade28e8

File tree

6 files changed

+312
-0
lines changed

6 files changed

+312
-0
lines changed

numba_dpex/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_atomic_ref_overloads,
1818
_group_barrier_overloads,
1919
_index_space_id_overloads,
20+
_private_array_overloads,
2021
)
2122
from .decorators import device_func, kernel
2223
from .launcher import call_kernel, call_kernel_async
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Implements the SPIR-V overload for the kernel_api.PrivateArray class constructor.
7+
"""
8+
9+
10+
import llvmlite.ir as llvmir
11+
from llvmlite.ir.builder import IRBuilder
12+
from numba.core import errors, types
13+
from numba.core.typing.templates import Signature
14+
from numba.extending import intrinsic, overload
15+
16+
from numba_dpex.core.types import USMNdArray
17+
from numba_dpex.experimental.target import DpexExpKernelTypingContext
18+
from numba_dpex.kernel_api import PrivateArray
19+
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
20+
make_spirv_generic_array_on_stack,
21+
require_literal,
22+
)
23+
from numba_dpex.utils import address_space as AddressSpace
24+
25+
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
26+
27+
28+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_private_array_ctor(
    ty_context, ty_shape, ty_dtype  # pylint: disable=unused-argument
):
    """Type and lower the construction of a private-address-space array.

    Typing rules:
        * ``ty_shape`` must pass ``require_literal`` (a literal, or a
          sequence whose elements are all literals).
        * ``ty_dtype`` must be a ``types.DType``.

    The resulting array type is a C-layout ``USMNdArray`` whose address space
    is ``AddressSpace.PRIVATE``. Its ``ndim`` is ``len(ty_shape)`` when the
    shape type has a length and 1 otherwise.

    Returns:
        A ``(signature, codegen)`` pair, as expected from an ``@intrinsic``.

    Raises:
        errors.TypingError: If the shape is not literal, or if ``ty_dtype``
            is not a ``types.DType``.
    """
    require_literal(ty_shape)

    if not isinstance(ty_dtype, types.DType):
        raise errors.TypingError("Second argument must be instance of DType")

    # A shape type without a length (a single literal) describes a 1-d array.
    ndim = len(ty_shape) if hasattr(ty_shape, "__len__") else 1

    sig = USMNdArray(
        dtype=ty_dtype.dtype,
        ndim=ndim,
        layout="C",
        addrspace=AddressSpace.PRIVATE,
    )(ty_shape, ty_dtype)

    def codegen(
        context: DpexExpKernelTypingContext,
        builder: IRBuilder,
        sig: Signature,
        args: list[llvmir.Value],
    ):
        # Only the shape argument is needed at lowering time; the element
        # type is already encoded in the signature's return type.
        array_struct = make_spirv_generic_array_on_stack(
            context, builder, sig.return_type, sig.args[0], args[0]
        )
        return array_struct._getvalue()  # pylint: disable=protected-access

    return sig, codegen
69+
70+
71+
@overload(
    PrivateArray,
    prefer_literal=True,
    target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_private_array_ctor(
    shape,
    dtype,
):
    """Overload of the constructor for the class
    class:`numba_dpex.kernel_api.PrivateArray`.

    The overload forwards both arguments to
    ``_intrinsic_private_array_ctor``, which performs the type checks and
    the lowering.

    Args:
        shape: Shape of the array to allocate. Must be typed as a literal, or
            as a sequence of literals.
        dtype: Element type of the array. Must be typed as a ``DType``.

    Raises:
        errors.TypingError: If the `shape` argument could not be parsed as a
            literal (or a sequence of literals).
        errors.TypingError: If the `dtype` argument is not an instance of
            ``DType``.

    """

    def ol_private_array_ctor_impl(
        shape,
        dtype,
    ):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_private_array_ctor(shape, dtype)

    return ol_private_array_ctor_impl

numba_dpex/kernel_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .index_space_ids import Group, Item, NdItem
1616
from .launcher import call_kernel
1717
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
18+
from .private_array import PrivateArray
1819
from .ranges import NdRange, Range
1920

2021
__all__ = [
@@ -28,6 +29,7 @@
2829
"Group",
2930
"NdItem",
3031
"Item",
32+
"PrivateArray",
3133
"group_barrier",
3234
"call_kernel",
3335
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Implements the PrivateArray class, a placeholder for an array allocated in
6+
private memory. The class is intended to be used only inside a kernel
7+
function; constructing or indexing it in pure Python code raises an error.
8+
"""
9+
10+
from dpctl.tensor import usm_ndarray
11+
12+
# Exception instance raised by every PrivateArray method when the class is
# used from pure Python code.
KernelUseOnlyError = NotImplementedError("Only for use inside kernel")


class PrivateArray(usm_ndarray):
    """
    The ``PrivateArray`` class is a placeholder for an array that is
    allocated in private memory from within a :func:`numba_dpex.kernel`
    decorated function.

    The class has no pure Python implementation: every method defined here
    raises ``NotImplementedError``.
    """

    def __init__(self, shape, dtype) -> None:
        """Creates a new PrivateArray instance of the given shape and dtype.

        Raises:
            NotImplementedError: Always, when called from pure Python code.
        """

        raise KernelUseOnlyError

    def __getitem__(self, idx_obj):
        """Returns the value stored at the position represented by idx_obj.

        Raises:
            NotImplementedError: Always, when called from pure Python code.
        """

        raise KernelUseOnlyError

    def __setitem__(self, idx_obj, val):
        """Assigns a new value to the position represented by idx_obj.

        Raises:
            NotImplementedError: Always, when called from pure Python code.
        """

        raise KernelUseOnlyError
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Contains SPIR-V specific array functions."""
2+
3+
import operator
4+
from functools import reduce
5+
from typing import Union
6+
7+
import llvmlite.ir as llvmir
8+
from llvmlite.ir.builder import IRBuilder
9+
from numba.core import cgutils, errors, types
10+
from numba.core.base import BaseContext
11+
12+
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
13+
from numba_dpex.ocl.oclimpl import _get_target_data
14+
15+
16+
def get_itemsize(context: SPIRVTargetContext, array_type: types.Array):
    """
    Return the ABI size of a single element of the given array or buffer type.

    Counterpart of numba.np.arrayobj.get_itemsize that measures the element
    against the target data returned by ``_get_target_data(context)``.
    """
    spirv_target_data = _get_target_data(context)
    element_lltype = context.get_data_type(array_type.dtype)
    return element_lltype.get_abi_size(spirv_target_data)
24+
25+
26+
def require_literal(literal_type: types.Type):
    """Checks that the numba type is a Literal.

    If the type has a length (for example a tuple type), every element of it
    is checked instead of the type itself.

    Args:
        literal_type: The numba type, or sized collection of numba types, to
            validate.

    Raises:
        TypingError: When the type, or any element of a sized type, is not a
            ``types.Literal``.
    """
    # Treat a single type as a one-element collection so that both cases
    # share the same check.
    if hasattr(literal_type, "__len__"):
        candidates = literal_type
    else:
        candidates = (literal_type,)

    for candidate in candidates:
        if not isinstance(candidate, types.Literal):
            raise errors.TypingError("requires literal type")
41+
42+
43+
def make_spirv_array(  # pylint: disable=too-many-arguments
    context: SPIRVTargetContext,
    builder: IRBuilder,
    ty_array: types.Array,
    ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
    shape: llvmir.Value,
    data: llvmir.Value,
):
    """Makes a SPIR-V array structure and populates it with ``data``.

    Args:
        context: Target context used to create and populate the array.
        builder: IR builder that receives the generated instructions.
        ty_array: Numba type of the array to create; its ``layout`` selects
            the stride order.
        ty_shape: Numba type of the shape, either a single integer literal
            type or a tuple type.
        shape: LLVM value holding the shape.
        data: LLVM value used as the array's data pointer.

    Returns:
        The populated array structure created by ``context.make_array``.
    """
    array_struct = context.make_array(ty_array)(context, builder)

    ll_itemsize = cgutils.intp_t(get_itemsize(context, ty_array))

    # Normalize the shape into parallel sequences of numba types and values.
    if isinstance(ty_shape, types.BaseTuple):
        dim_values = cgutils.unpack_tuple(builder, shape)
        dim_types = ty_shape
    else:
        dim_values = (shape,)
        dim_types = (ty_shape,)

    extents = [
        context.cast(builder, dim_value, dim_type, types.intp)
        for dim_type, dim_value in zip(dim_types, dim_values)
    ]

    # Strides are running products of the extents, starting from the item
    # size: first dimension first for "F" layout, last dimension first
    # otherwise.
    fortran_order = ty_array.layout == "F"
    running = ll_itemsize
    strides = []
    for extent in extents if fortran_order else reversed(extents):
        strides.append(running)
        running = builder.mul(running, extent)
    if not fortran_order:
        # Collected from the last dimension backwards; restore dimension order.
        strides.reverse()

    context.populate_array(
        array_struct,
        data=data,
        shape=extents,
        strides=strides,
        itemsize=ll_itemsize,
    )

    return array_struct
89+
90+
91+
def allocate_array_data_on_stack(
    context: BaseContext,
    builder: IRBuilder,
    ty_array: types.Array,
    ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
):
    """Allocates a flat buffer for an array of the given shape via ``alloca``.

    The number of elements is the product of the literal values in
    ``ty_shape``; a non-tuple ``ty_shape`` is treated as a single dimension.

    Returns:
        The value returned by ``cgutils.alloca_once`` for the buffer.
    """
    dims = ty_shape if isinstance(ty_shape, types.BaseTuple) else (ty_shape,)

    element_lltype = context.get_data_type(ty_array.dtype)
    element_count = reduce(operator.mul, [dim.literal_value for dim in dims])

    return cgutils.alloca_once(builder, element_lltype, size=element_count)
106+
107+
108+
def make_spirv_generic_array_on_stack(
    context: SPIRVTargetContext,
    builder: IRBuilder,
    ty_array: types.Array,
    ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
    shape: llvmir.Value,
):
    """Makes a SPIR-V array of the given shape backed by a fresh buffer.

    The buffer comes from ``allocate_array_data_on_stack`` and is wrapped in
    an array structure by ``make_spirv_array``; no values are written to it
    here.
    """
    stack_data = allocate_array_data_on_stack(
        context, builder, ty_array, ty_shape
    )
    return make_spirv_array(
        context, builder, ty_array, ty_shape, shape, data=stack_data
    )
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import dpnp
2+
import numpy as np
3+
import pytest
4+
5+
import numba_dpex.experimental as dpex_exp
6+
from numba_dpex.kernel_api import Item, PrivateArray, Range
7+
8+
9+
@dpex_exp.kernel
def private_array_kernel(item: Item, a):
    """Stores the sum of the squares 0..9 into ``a`` via a 1-d PrivateArray."""
    gid = item.get_linear_id()
    squares = PrivateArray(10, a.dtype)

    for k in range(10):
        squares[k] = k * k

    a[gid] = 0
    for k in range(10):
        a[gid] += squares[k]
20+
21+
22+
@dpex_exp.kernel
def private_2d_array_kernel(item: Item, a):
    """Stores the sum of the squares 0..9 into ``a`` via a 5x2 PrivateArray."""
    gid = item.get_linear_id()
    squares = PrivateArray(shape=(5, 2), dtype=a.dtype)

    # Each k in 0..9 maps to a distinct cell (k % 5, k // 5) of the 5x2 array.
    for k in range(10):
        squares[k % 5, k // 5] = k * k

    a[gid] = 0
    for k in range(10):
        a[gid] += squares[k % 5, k // 5]
33+
34+
35+
@pytest.mark.parametrize(
    "kernel", [private_array_kernel, private_2d_array_kernel]
)
def test_private_array(kernel):
    """Checks that every element of the output equals the sum of squares 0..9."""
    a = dpnp.empty(10, dtype=dpnp.float32)
    dpex_exp.call_kernel(kernel, Range(a.size), a)

    # Closed form for 0^2 + 1^2 + ... + n^2 is n(n+1)(2n+1)/6, here with n = 9.
    last = 9
    sum_of_squares = last * (last + 1) * (2 * last + 1) / 6
    want = np.full(a.size, sum_of_squares, dtype=np.float32)

    assert np.array_equal(want, a.asnumpy())


if __name__ == "__main__":
    test_private_array(private_array_kernel)

0 commit comments

Comments
 (0)