Skip to content

Commit b550f42

Browse files
committed
Add PrivateArray kernel_api
1 parent bd7fb7d commit b550f42

File tree

6 files changed

+305
-0
lines changed

6 files changed

+305
-0
lines changed

numba_dpex/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_atomic_ref_overloads,
1818
_group_barrier_overloads,
1919
_index_space_id_overloads,
20+
_private_array_overloads,
2021
)
2122
from .decorators import device_func, kernel
2223
from .launcher import call_kernel, call_kernel_async
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Implements the SPIR-V overloads for the kernel_api.PrivateArray class.
7+
"""
8+
9+
10+
import llvmlite.ir as llvmir
11+
from llvmlite.ir.builder import IRBuilder
12+
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
13+
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
14+
from numba.core.typing.templates import Signature
15+
from numba.extending import intrinsic, overload
16+
17+
from numba_dpex.core.types import USMNdArray
18+
from numba_dpex.experimental.target import DpexExpKernelTypingContext
19+
from numba_dpex.kernel_api import PrivateArray
20+
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
21+
make_spirv_generic_array_on_stack,
22+
require_literal,
23+
)
24+
from numba_dpex.utils import address_space as AddressSpace
25+
26+
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
27+
28+
29+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
30+
def _intrinsic_private_array_ctor(
31+
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument
32+
):
33+
require_literal(ty_shape)
34+
35+
ty_array = USMNdArray(
36+
dtype=_ty_parse_dtype(ty_dtype),
37+
ndim=_ty_parse_shape(ty_shape),
38+
layout="C",
39+
addrspace=AddressSpace.PRIVATE,
40+
)
41+
42+
sig = ty_array(ty_shape, ty_dtype)
43+
44+
def codegen(
45+
context: DpexExpKernelTypingContext,
46+
builder: IRBuilder,
47+
sig: Signature,
48+
args: list[llvmir.Value],
49+
):
50+
shape = args[0]
51+
ty_shape = sig.args[0]
52+
ty_array = sig.return_type
53+
54+
ary = make_spirv_generic_array_on_stack(
55+
context, builder, ty_array, ty_shape, shape
56+
)
57+
return ary._getvalue() # pylint: disable=protected-access
58+
59+
return (
60+
sig,
61+
codegen,
62+
)
63+
64+
65+
@overload(
66+
PrivateArray,
67+
prefer_literal=True,
68+
target=DPEX_KERNEL_EXP_TARGET_NAME,
69+
)
70+
def ol_private_array_ctor(
71+
shape,
72+
dtype,
73+
):
74+
"""Overload of the constructor for the class
75+
class:`numba_dpex.kernel_api.PrivateArray`.
76+
77+
Raises:
78+
errors.TypingError: If the shape argument is not a shape compatible
79+
type.
80+
errors.TypingError: If the dtype argument is not a dtype compatible
81+
type.
82+
"""
83+
84+
def ol_private_array_ctor_impl(
85+
shape,
86+
dtype,
87+
):
88+
# pylint: disable=no-value-for-parameter
89+
return _intrinsic_private_array_ctor(shape, dtype)
90+
91+
return ol_private_array_ctor_impl

numba_dpex/kernel_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .index_space_ids import Group, Item, NdItem
1616
from .launcher import call_kernel
1717
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
18+
from .private_array import PrivateArray
1819
from .ranges import NdRange, Range
1920

2021
__all__ = [
@@ -28,6 +29,7 @@
2829
"Group",
2930
"NdItem",
3031
"Item",
32+
"PrivateArray",
3133
"group_barrier",
3234
"call_kernel",
3335
]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Implements a simple array intended to be used inside kernel work item.
6+
Implementation is intended to be used in pure Python code when prototyping a
7+
kernel function.
8+
"""
9+
10+
from numpy import ndarray
11+
12+
13+
class PrivateArray:
14+
"""
15+
The ``PrivateArray`` class is an simple version of array intended to be used
16+
inside kernel work item.
17+
"""
18+
19+
def __init__(self, shape, dtype) -> None:
20+
"""Creates a new PrivateArray instance of the given shape and dtype."""
21+
22+
self._data = ndarray(shape=shape, dtype=dtype)
23+
24+
def __getitem__(self, idx_obj):
25+
"""Returns the value stored at the position represented by idx_obj in
26+
the self._data ndarray.
27+
"""
28+
29+
return self._data[idx_obj]
30+
31+
def __setitem__(self, idx_obj, val):
32+
"""Assigns a new value to the position represented by idx_obj in
33+
the self._data ndarray.
34+
"""
35+
36+
self._data[idx_obj] = val
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Contains SPIR-V specific array functions."""
6+
7+
import operator
8+
from functools import reduce
9+
from typing import Union
10+
11+
import llvmlite.ir as llvmir
12+
from llvmlite.ir.builder import IRBuilder
13+
from numba.core import cgutils, errors, types
14+
from numba.core.base import BaseContext
15+
16+
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
17+
from numba_dpex.ocl.oclimpl import _get_target_data
18+
19+
20+
def get_itemsize(context: SPIRVTargetContext, array_type: types.Array):
21+
"""
22+
Return the item size for the given array or buffer type.
23+
Same as numba.np.arrayobj.get_itemsize, but using spirv data.
24+
"""
25+
targetdata = _get_target_data(context)
26+
lldtype = context.get_data_type(array_type.dtype)
27+
return lldtype.get_abi_size(targetdata)
28+
29+
30+
def require_literal(literal_type: types.Type):
31+
"""Checks if the numba type is Literal. If iterable object is passed,
32+
checks that every element is Literal.
33+
34+
Raises:
35+
TypingError: When argument is not Iterable.
36+
"""
37+
if not hasattr(literal_type, "__len__"):
38+
if not isinstance(literal_type, types.Literal):
39+
raise errors.TypingError("requires literal type")
40+
return
41+
42+
for i, _ in enumerate(literal_type):
43+
if not isinstance(literal_type[i], types.Literal):
44+
raise errors.TypingError("requires literal type")
45+
46+
47+
def make_spirv_array( # pylint: disable=too-many-arguments
48+
context: SPIRVTargetContext,
49+
builder: IRBuilder,
50+
ty_array: types.Array,
51+
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
52+
shape: llvmir.Value,
53+
data: llvmir.Value,
54+
):
55+
"""Makes SPIR-V array and fills it data."""
56+
# Create array object
57+
ary = context.make_array(ty_array)(context, builder)
58+
59+
itemsize = get_itemsize(context, ty_array)
60+
ll_itemsize = cgutils.intp_t(itemsize)
61+
62+
if isinstance(ty_shape, types.BaseTuple):
63+
shapes = cgutils.unpack_tuple(builder, shape)
64+
else:
65+
ty_shape = (ty_shape,)
66+
shapes = (shape,)
67+
shapes = [
68+
context.cast(builder, value, fromty, types.intp)
69+
for fromty, value in zip(ty_shape, shapes)
70+
]
71+
72+
off = ll_itemsize
73+
strides = []
74+
if ty_array.layout == "F":
75+
for s in shapes:
76+
strides.append(off)
77+
off = builder.mul(off, s)
78+
else:
79+
for s in reversed(shapes):
80+
strides.append(off)
81+
off = builder.mul(off, s)
82+
strides.reverse()
83+
84+
context.populate_array(
85+
ary,
86+
data=data,
87+
shape=shapes,
88+
strides=strides,
89+
itemsize=ll_itemsize,
90+
)
91+
92+
return ary
93+
94+
95+
def allocate_array_data_on_stack(
96+
context: BaseContext,
97+
builder: IRBuilder,
98+
ty_array: types.Array,
99+
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
100+
):
101+
"""Allocates flat array of given shape on the stack."""
102+
if not isinstance(ty_shape, types.BaseTuple):
103+
ty_shape = (ty_shape,)
104+
105+
return cgutils.alloca_once(
106+
builder,
107+
context.get_data_type(ty_array.dtype),
108+
size=reduce(operator.mul, [s.literal_value for s in ty_shape]),
109+
)
110+
111+
112+
def make_spirv_generic_array_on_stack(
113+
context: SPIRVTargetContext,
114+
builder: IRBuilder,
115+
ty_array: types.Array,
116+
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
117+
shape: llvmir.Value,
118+
):
119+
"""Makes SPIR-V array of given shape with empty data."""
120+
data = allocate_array_data_on_stack(context, builder, ty_array, ty_shape)
121+
return make_spirv_array(context, builder, ty_array, ty_shape, shape, data)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpnp
6+
import numpy as np
7+
import pytest
8+
9+
import numba_dpex.experimental as dpex_exp
10+
from numba_dpex.kernel_api import Item, PrivateArray, Range
11+
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel
12+
13+
14+
def private_array_kernel(item: Item, a):
15+
i = item.get_linear_id()
16+
p = PrivateArray(10, a.dtype)
17+
18+
for j in range(10):
19+
p[j] = j * j
20+
21+
a[i] = 0
22+
for j in range(10):
23+
a[i] += p[j]
24+
25+
26+
def private_2d_array_kernel(item: Item, a):
27+
i = item.get_linear_id()
28+
p = PrivateArray(shape=(5, 2), dtype=a.dtype)
29+
30+
for j in range(10):
31+
p[j % 5, j // 5] = j * j
32+
33+
a[i] = 0
34+
for j in range(10):
35+
a[i] += p[j % 5, j // 5]
36+
37+
38+
@pytest.mark.parametrize(
39+
"kernel", [private_array_kernel, private_2d_array_kernel]
40+
)
41+
@pytest.mark.parametrize(
42+
"call_kernel, decorator",
43+
[(dpex_exp.call_kernel, dpex_exp.kernel), (kapi_call_kernel, lambda a: a)],
44+
)
45+
def test_private_array(call_kernel, decorator, kernel):
46+
kernel = decorator(kernel)
47+
48+
a = dpnp.empty(10, dtype=dpnp.float32)
49+
call_kernel(kernel, Range(a.size), a)
50+
51+
# sum of squares from 1 to n: n*(n+1)*(2*n+1)/6
52+
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32)
53+
54+
assert np.array_equal(want, a.asnumpy())

0 commit comments

Comments
 (0)