Skip to content

Commit 6e03083

Browse files
committed
Add PrivateArray kernel_api
1 parent bd7fb7d commit 6e03083

File tree

6 files changed

+297
-0
lines changed

6 files changed

+297
-0
lines changed

numba_dpex/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_atomic_ref_overloads,
1818
_group_barrier_overloads,
1919
_index_space_id_overloads,
20+
_private_array_overloads,
2021
)
2122
from .decorators import device_func, kernel
2223
from .launcher import call_kernel, call_kernel_async
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Implements the SPIR-V overloads for the kernel_api.PrivateArray class.
7+
"""
8+
9+
10+
import llvmlite.ir as llvmir
11+
from llvmlite.ir.builder import IRBuilder
12+
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
13+
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
14+
from numba.core.typing.templates import Signature
15+
from numba.extending import intrinsic, overload
16+
17+
from numba_dpex.core.types import USMNdArray
18+
from numba_dpex.experimental.target import DpexExpKernelTypingContext
19+
from numba_dpex.kernel_api import PrivateArray
20+
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
21+
make_spirv_generic_array_on_stack,
22+
require_literal,
23+
)
24+
from numba_dpex.utils import address_space as AddressSpace
25+
26+
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
27+
28+
29+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
30+
def _intrinsic_private_array_ctor(
31+
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument
32+
):
33+
require_literal(ty_shape)
34+
35+
ty_array = USMNdArray(
36+
dtype=_ty_parse_dtype(ty_dtype),
37+
ndim=_ty_parse_shape(ty_shape),
38+
layout="C",
39+
addrspace=AddressSpace.PRIVATE,
40+
)
41+
42+
sig = ty_array(ty_shape, ty_dtype)
43+
44+
def codegen(
45+
context: DpexExpKernelTypingContext,
46+
builder: IRBuilder,
47+
sig: Signature,
48+
args: list[llvmir.Value],
49+
):
50+
shape = args[0]
51+
ty_shape = sig.args[0]
52+
ty_array = sig.return_type
53+
54+
ary = make_spirv_generic_array_on_stack(
55+
context, builder, ty_array, ty_shape, shape
56+
)
57+
return ary._getvalue() # pylint: disable=protected-access
58+
59+
return (
60+
sig,
61+
codegen,
62+
)
63+
64+
65+
@overload(
66+
PrivateArray,
67+
prefer_literal=True,
68+
target=DPEX_KERNEL_EXP_TARGET_NAME,
69+
)
70+
def ol_private_array_ctor(
71+
shape,
72+
dtype,
73+
):
74+
"""Overload of the constructor for the class
75+
class:`numba_dpex.kernel_api.PrivateArray`.
76+
77+
Raises:
78+
errors.TypingError: If the shape argument is not a shape compatible
79+
type.
80+
errors.TypingError: If the dtype argument is not a dtype compatible
81+
type.
82+
"""
83+
84+
def ol_private_array_ctor_impl(
85+
shape,
86+
dtype,
87+
):
88+
# pylint: disable=no-value-for-parameter
89+
return _intrinsic_private_array_ctor(shape, dtype)
90+
91+
return ol_private_array_ctor_impl

numba_dpex/kernel_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .index_space_ids import Group, Item, NdItem
1616
from .launcher import call_kernel
1717
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
18+
from .private_array import PrivateArray
1819
from .ranges import NdRange, Range
1920

2021
__all__ = [
@@ -28,6 +29,7 @@
2829
"Group",
2930
"NdItem",
3031
"Item",
32+
"PrivateArray",
3133
"group_barrier",
3234
"call_kernel",
3335
]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Implements a simple array intended to be used inside kernel work item.
6+
Implementation is intended to be used in pure Python code when prototyping a
7+
kernel function.
8+
"""
9+
10+
from numpy import ndarray
11+
12+
13+
class PrivateArray:
14+
"""
15+
The ``PrivateArray`` class is an simple version of array intended to be used
16+
inside kernel work item.
17+
"""
18+
19+
def __init__(self, shape, dtype) -> None:
20+
"""Creates a new PrivateArray instance of the given shape and dtype."""
21+
22+
self._data = ndarray(shape=shape, dtype=dtype)
23+
24+
def __getitem__(self, idx_obj):
25+
"""Returns the value stored at the position represented by idx_obj in
26+
the self._data ndarray.
27+
"""
28+
29+
return self._data[idx_obj]
30+
31+
def __setitem__(self, idx_obj, val):
32+
"""Assigns a new value to the position represented by idx_obj in
33+
the self._data ndarray.
34+
"""
35+
36+
self._data[idx_obj] = val
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Contains spriv specific array functions."""
2+
3+
import operator
4+
from functools import reduce
5+
from typing import Union
6+
7+
import llvmlite.ir as llvmir
8+
from llvmlite.ir.builder import IRBuilder
9+
from numba.core import cgutils, errors, types
10+
from numba.core.base import BaseContext
11+
12+
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
13+
from numba_dpex.ocl.oclimpl import _get_target_data
14+
15+
16+
def get_itemsize(context: SPIRVTargetContext, array_type: types.Array):
17+
"""
18+
Return the item size for the given array or buffer type.
19+
Same as numba.np.arrayobj.get_itemsize, but using spirv data.
20+
"""
21+
targetdata = _get_target_data(context)
22+
lldtype = context.get_data_type(array_type.dtype)
23+
return lldtype.get_abi_size(targetdata)
24+
25+
26+
def require_literal(literal_type: types.Type):
27+
"""Checks if the numba type is Literal. If iterable object is passed,
28+
checks that every element is Literal.
29+
30+
Raises:
31+
TypingError: When argument is not Iterable.
32+
"""
33+
if not hasattr(literal_type, "__len__"):
34+
if not isinstance(literal_type, types.Literal):
35+
raise errors.TypingError("requires literal type")
36+
return
37+
38+
for i, _ in enumerate(literal_type):
39+
if not isinstance(literal_type[i], types.Literal):
40+
raise errors.TypingError("requires literal type")
41+
42+
43+
def make_spirv_array( # pylint: disable=too-many-arguments
44+
context: SPIRVTargetContext,
45+
builder: IRBuilder,
46+
ty_array: types.Array,
47+
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
48+
shape: llvmir.Value,
49+
data: llvmir.Value,
50+
):
51+
"""Makes SPIR-V array and fills it data."""
52+
# Create array object
53+
ary = context.make_array(ty_array)(context, builder)
54+
55+
itemsize = get_itemsize(context, ty_array)
56+
ll_itemsize = cgutils.intp_t(itemsize)
57+
58+
if isinstance(ty_shape, types.BaseTuple):
59+
shapes = cgutils.unpack_tuple(builder, shape)
60+
else:
61+
ty_shape = (ty_shape,)
62+
shapes = (shape,)
63+
shapes = [
64+
context.cast(builder, value, fromty, types.intp)
65+
for fromty, value in zip(ty_shape, shapes)
66+
]
67+
68+
off = ll_itemsize
69+
strides = []
70+
if ty_array.layout == "F":
71+
for s in shapes:
72+
strides.append(off)
73+
off = builder.mul(off, s)
74+
else:
75+
for s in reversed(shapes):
76+
strides.append(off)
77+
off = builder.mul(off, s)
78+
strides.reverse()
79+
80+
context.populate_array(
81+
ary,
82+
data=data,
83+
shape=shapes,
84+
strides=strides,
85+
itemsize=ll_itemsize,
86+
)
87+
88+
return ary
89+
90+
91+
def allocate_array_data_on_stack(
92+
context: BaseContext,
93+
builder: IRBuilder,
94+
ty_array: types.Array,
95+
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
96+
):
97+
"""Allocates flat array of given shape on the stack."""
98+
if not isinstance(ty_shape, types.BaseTuple):
99+
ty_shape = (ty_shape,)
100+
101+
return cgutils.alloca_once(
102+
builder,
103+
context.get_data_type(ty_array.dtype),
104+
size=reduce(operator.mul, [s.literal_value for s in ty_shape]),
105+
)
106+
107+
108+
def make_spirv_generic_array_on_stack(
109+
context: SPIRVTargetContext,
110+
builder: IRBuilder,
111+
ty_array: types.Array,
112+
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
113+
shape: llvmir.Value,
114+
):
115+
"""Makes SPIR-V array of given shape with empty data."""
116+
data = allocate_array_data_on_stack(context, builder, ty_array, ty_shape)
117+
return make_spirv_array(context, builder, ty_array, ty_shape, shape, data)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import dpnp
2+
import numpy as np
3+
import pytest
4+
5+
import numba_dpex.experimental as dpex_exp
6+
from numba_dpex.kernel_api import Item, PrivateArray, Range
7+
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel
8+
9+
10+
def private_array_kernel(item: Item, a):
11+
i = item.get_linear_id()
12+
p = PrivateArray(10, a.dtype)
13+
14+
for j in range(10):
15+
p[j] = j * j
16+
17+
a[i] = 0
18+
for j in range(10):
19+
a[i] += p[j]
20+
21+
22+
def private_2d_array_kernel(item: Item, a):
23+
i = item.get_linear_id()
24+
p = PrivateArray(shape=(5, 2), dtype=a.dtype)
25+
26+
for j in range(10):
27+
p[j % 5, j // 5] = j * j
28+
29+
a[i] = 0
30+
for j in range(10):
31+
a[i] += p[j % 5, j // 5]
32+
33+
34+
@pytest.mark.parametrize(
35+
"kernel", [private_array_kernel, private_2d_array_kernel]
36+
)
37+
@pytest.mark.parametrize(
38+
"call_kernel, decorator",
39+
[(dpex_exp.call_kernel, dpex_exp.kernel), (kapi_call_kernel, lambda a: a)],
40+
)
41+
def test_private_array(call_kernel, decorator, kernel):
42+
kernel = decorator(kernel)
43+
44+
a = dpnp.empty(10, dtype=dpnp.float32)
45+
call_kernel(kernel, Range(a.size), a)
46+
47+
# sum of squares from 1 to n: n*(n+1)*(2*n+1)/6
48+
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32)
49+
50+
assert np.array_equal(want, a.asnumpy())

0 commit comments

Comments
 (0)