-
Notifications
You must be signed in to change notification settings - Fork 32
Add PrivateArray kernel_api #1370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_private_array_overloads.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# SPDX-FileCopyrightText: 2024 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
""" | ||
Implements the SPIR-V overloads for the kernel_api.PrivateArray class. | ||
""" | ||
|
||
|
||
import llvmlite.ir as llvmir | ||
from llvmlite.ir.builder import IRBuilder | ||
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype | ||
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape | ||
from numba.core.typing.templates import Signature | ||
from numba.extending import intrinsic, overload | ||
|
||
from numba_dpex.core.types import USMNdArray | ||
from numba_dpex.experimental.target import DpexExpKernelTypingContext | ||
from numba_dpex.kernel_api import PrivateArray | ||
from numba_dpex.kernel_api_impl.spirv.arrayobj import ( | ||
make_spirv_generic_array_on_stack, | ||
require_literal, | ||
) | ||
from numba_dpex.utils import address_space as AddressSpace | ||
|
||
from ..target import DPEX_KERNEL_EXP_TARGET_NAME | ||
|
||
|
||
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) | ||
def _intrinsic_private_array_ctor( | ||
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument | ||
): | ||
require_literal(ty_shape) | ||
|
||
ty_array = USMNdArray( | ||
dtype=_ty_parse_dtype(ty_dtype), | ||
ndim=_ty_parse_shape(ty_shape), | ||
layout="C", | ||
addrspace=AddressSpace.PRIVATE, | ||
) | ||
|
||
sig = ty_array(ty_shape, ty_dtype) | ||
|
||
def codegen( | ||
context: DpexExpKernelTypingContext, | ||
builder: IRBuilder, | ||
sig: Signature, | ||
args: list[llvmir.Value], | ||
): | ||
shape = args[0] | ||
ty_shape = sig.args[0] | ||
ty_array = sig.return_type | ||
|
||
ary = make_spirv_generic_array_on_stack( | ||
context, builder, ty_array, ty_shape, shape | ||
) | ||
return ary._getvalue() # pylint: disable=protected-access | ||
|
||
return ( | ||
sig, | ||
codegen, | ||
) | ||
|
||
|
||
@overload( | ||
PrivateArray, | ||
prefer_literal=True, | ||
target=DPEX_KERNEL_EXP_TARGET_NAME, | ||
) | ||
def ol_private_array_ctor( | ||
shape, | ||
dtype, | ||
): | ||
"""Overload of the constructor for the class | ||
class:`numba_dpex.kernel_api.PrivateArray`. | ||
|
||
Raises: | ||
errors.TypingError: If the shape argument is not a shape compatible | ||
type. | ||
errors.TypingError: If the dtype argument is not a dtype compatible | ||
type. | ||
""" | ||
|
||
def ol_private_array_ctor_impl( | ||
shape, | ||
dtype, | ||
): | ||
# pylint: disable=no-value-for-parameter | ||
return _intrinsic_private_array_ctor(shape, dtype) | ||
|
||
return ol_private_array_ctor_impl |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# SPDX-FileCopyrightText: 2024 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Implements a simple array intended to be used inside kernel work item. | ||
Implementation is intended to be used in pure Python code when prototyping a | ||
kernel function. | ||
""" | ||
|
||
from numpy import ndarray | ||
|
||
|
||
class PrivateArray: | ||
""" | ||
The ``PrivateArray`` class is an simple version of array intended to be used | ||
inside kernel work item. | ||
""" | ||
|
||
def __init__(self, shape, dtype) -> None: | ||
"""Creates a new PrivateArray instance of the given shape and dtype.""" | ||
|
||
self._data = ndarray(shape=shape, dtype=dtype) | ||
|
||
def __getitem__(self, idx_obj): | ||
"""Returns the value stored at the position represented by idx_obj in | ||
the self._data ndarray. | ||
""" | ||
|
||
return self._data[idx_obj] | ||
|
||
def __setitem__(self, idx_obj, val): | ||
"""Assigns a new value to the position represented by idx_obj in | ||
the self._data ndarray. | ||
""" | ||
|
||
self._data[idx_obj] = val |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# SPDX-FileCopyrightText: 2024 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Contains SPIR-V specific array functions.""" | ||
|
||
import operator | ||
from functools import reduce | ||
from typing import Union | ||
|
||
import llvmlite.ir as llvmir | ||
from llvmlite.ir.builder import IRBuilder | ||
from numba.core import cgutils, errors, types | ||
from numba.core.base import BaseContext | ||
|
||
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext | ||
from numba_dpex.ocl.oclimpl import _get_target_data | ||
|
||
|
||
def get_itemsize(context: SPIRVTargetContext, array_type: types.Array): | ||
""" | ||
Return the item size for the given array or buffer type. | ||
Same as numba.np.arrayobj.get_itemsize, but using spirv data. | ||
""" | ||
targetdata = _get_target_data(context) | ||
lldtype = context.get_data_type(array_type.dtype) | ||
return lldtype.get_abi_size(targetdata) | ||
|
||
|
||
def require_literal(literal_type: types.Type): | ||
"""Checks if the numba type is Literal. If iterable object is passed, | ||
checks that every element is Literal. | ||
|
||
Raises: | ||
TypingError: When argument is not Iterable. | ||
""" | ||
if not hasattr(literal_type, "__len__"): | ||
if not isinstance(literal_type, types.Literal): | ||
raise errors.TypingError("requires literal type") | ||
return | ||
|
||
for i, _ in enumerate(literal_type): | ||
if not isinstance(literal_type[i], types.Literal): | ||
raise errors.TypingError("requires literal type") | ||
|
||
|
||
def make_spirv_array( # pylint: disable=too-many-arguments | ||
context: SPIRVTargetContext, | ||
builder: IRBuilder, | ||
ty_array: types.Array, | ||
ty_shape: Union[types.IntegerLiteral, types.BaseTuple], | ||
shape: llvmir.Value, | ||
data: llvmir.Value, | ||
): | ||
"""Makes SPIR-V array and fills it data.""" | ||
# Create array object | ||
ary = context.make_array(ty_array)(context, builder) | ||
|
||
itemsize = get_itemsize(context, ty_array) | ||
ll_itemsize = cgutils.intp_t(itemsize) | ||
|
||
if isinstance(ty_shape, types.BaseTuple): | ||
shapes = cgutils.unpack_tuple(builder, shape) | ||
else: | ||
ty_shape = (ty_shape,) | ||
shapes = (shape,) | ||
shapes = [ | ||
context.cast(builder, value, fromty, types.intp) | ||
for fromty, value in zip(ty_shape, shapes) | ||
] | ||
|
||
off = ll_itemsize | ||
strides = [] | ||
if ty_array.layout == "F": | ||
for s in shapes: | ||
strides.append(off) | ||
off = builder.mul(off, s) | ||
else: | ||
for s in reversed(shapes): | ||
strides.append(off) | ||
off = builder.mul(off, s) | ||
strides.reverse() | ||
|
||
context.populate_array( | ||
ary, | ||
data=data, | ||
shape=shapes, | ||
strides=strides, | ||
itemsize=ll_itemsize, | ||
) | ||
|
||
return ary | ||
|
||
|
||
def allocate_array_data_on_stack( | ||
context: BaseContext, | ||
builder: IRBuilder, | ||
ty_array: types.Array, | ||
ty_shape: Union[types.IntegerLiteral, types.BaseTuple], | ||
): | ||
"""Allocates flat array of given shape on the stack.""" | ||
if not isinstance(ty_shape, types.BaseTuple): | ||
ty_shape = (ty_shape,) | ||
|
||
return cgutils.alloca_once( | ||
diptorupd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
builder, | ||
context.get_data_type(ty_array.dtype), | ||
size=reduce(operator.mul, [s.literal_value for s in ty_shape]), | ||
) | ||
|
||
|
||
def make_spirv_generic_array_on_stack( | ||
context: SPIRVTargetContext, | ||
builder: IRBuilder, | ||
ty_array: types.Array, | ||
ty_shape: Union[types.IntegerLiteral, types.BaseTuple], | ||
shape: llvmir.Value, | ||
): | ||
"""Makes SPIR-V array of given shape with empty data.""" | ||
data = allocate_array_data_on_stack(context, builder, ty_array, ty_shape) | ||
return make_spirv_array(context, builder, ty_array, ty_shape, shape, data) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# SPDX-FileCopyrightText: 2024 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import dpnp | ||
import numpy as np | ||
import pytest | ||
|
||
import numba_dpex.experimental as dpex_exp | ||
from numba_dpex.kernel_api import Item, PrivateArray, Range | ||
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel | ||
|
||
|
||
def private_array_kernel(item: Item, a): | ||
i = item.get_linear_id() | ||
p = PrivateArray(10, a.dtype) | ||
|
||
for j in range(10): | ||
p[j] = j * j | ||
|
||
a[i] = 0 | ||
for j in range(10): | ||
a[i] += p[j] | ||
|
||
|
||
def private_2d_array_kernel(item: Item, a): | ||
i = item.get_linear_id() | ||
p = PrivateArray(shape=(5, 2), dtype=a.dtype) | ||
|
||
for j in range(10): | ||
p[j % 5, j // 5] = j * j | ||
|
||
a[i] = 0 | ||
for j in range(10): | ||
a[i] += p[j % 5, j // 5] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"kernel", [private_array_kernel, private_2d_array_kernel] | ||
) | ||
@pytest.mark.parametrize( | ||
"call_kernel, decorator", | ||
[(dpex_exp.call_kernel, dpex_exp.kernel), (kapi_call_kernel, lambda a: a)], | ||
) | ||
def test_private_array(call_kernel, decorator, kernel): | ||
kernel = decorator(kernel) | ||
|
||
a = dpnp.empty(10, dtype=dpnp.float32) | ||
call_kernel(kernel, Range(a.size), a) | ||
|
||
# sum of squares from 1 to n: n*(n+1)*(2*n+1)/6 | ||
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32) | ||
|
||
assert np.array_equal(want, a.asnumpy()) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.