Skip to content

Implement take_along_axis function per Python Array API #1778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ by either integral arrays of indices or boolean mask arrays.
place
put
take
take_along_axis
10 changes: 9 additions & 1 deletion dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,14 @@
)
from dpctl.tensor._device import Device
from dpctl.tensor._dlpack import from_dlpack
from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take
from dpctl.tensor._indexing_functions import (
extract,
nonzero,
place,
put,
take,
take_along_axis,
)
from dpctl.tensor._linear_algebra_functions import (
matmul,
matrix_transpose,
Expand Down Expand Up @@ -376,4 +383,5 @@
"nextafter",
"diff",
"count_nonzero",
"take_along_axis",
]
9 changes: 7 additions & 2 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,13 +795,18 @@ def _nonzero_impl(ary):
return res


def _take_multi_index(ary, inds, p):
def _take_multi_index(ary, inds, p, mode=0):
if not isinstance(ary, dpt.usm_ndarray):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
)
ary_nd = ary.ndim
p = normalize_axis_index(operator.index(p), ary_nd)
mode = operator.index(mode)
if mode not in [0, 1]:
raise ValueError(
"Invalid value for mode keyword, only 0 or 1 is supported"
)
queues_ = [
ary.sycl_queue,
]
Expand Down Expand Up @@ -860,7 +865,7 @@ def _take_multi_index(ary, inds, p):
ind=inds,
dst=res,
axis_start=p,
mode=0,
mode=mode,
sycl_queue=exec_q,
depends=dep_ev,
)
Expand Down
81 changes: 80 additions & 1 deletion dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils

from ._copy_utils import _extract_impl, _nonzero_impl
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
from ._numpy_helper import normalize_axis_index


Expand Down Expand Up @@ -423,3 +423,82 @@ def nonzero(arr):
if arr.ndim == 0:
raise ValueError("Array of positive rank is expected")
return _nonzero_impl(arr)


def _range(sh_i, i, nd, q, usm_t, dt):
ind = dpt.arange(sh_i, dtype=dt, usm_type=usm_t, sycl_queue=q)
ind.shape = tuple(sh_i if i == j else 1 for j in range(nd))
return ind


def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
"""
Returns elements from an array at the one-dimensional indices specified
by ``indices`` along a provided ``axis``.

Args:
x (usm_ndarray):
input array. Must be compatible with ``indices``, except for the
axis (dimension) specified by ``axis``.
indices (usm_ndarray):
array indices. Must have the same rank (i.e., number of dimensions)
as ``x``.
axis: int
axis along which to select values. If ``axis`` is negative, the
function determines the axis along which to select values by
counting from the last dimension. Default: ``-1``.
mode (str, optional):
How out-of-bounds indices will be handled. Possible values
are:

- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
negative indices.
- ``"clip"``: clips indices to (``0 <= i < n``).

Default: ``"wrap"``.

Returns:
usm_ndarray:
an array having the same data type as ``x``. The returned array has
the same rank (i.e., number of dimensions) as ``x`` and a shape
determined according to :ref:`broadcasting`, except for the axis
(dimension) specified by ``axis`` whose size must equal the size
of the corresponding axis (dimension) in ``indices``.

Note:
Treatment of the out-of-bound indices in ``indices`` array is controlled
by the value of ``mode`` keyword.
"""
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
if not isinstance(indices, dpt.usm_ndarray):
raise TypeError(
f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}"
)
x_nd = x.ndim
if x_nd != indices.ndim:
raise ValueError(
"Number of dimensions in the first and the second "
"argument arrays must be equal"
)
pp = normalize_axis_index(operator.index(axis), x_nd)
out_usm_type = dpctl.utils.get_coerced_usm_type(
(x.usm_type, indices.usm_type)
)
exec_q = dpctl.utils.get_execution_queue((x.sycl_queue, indices.sycl_queue))
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"Execution placement can not be unambiguously inferred "
"from input arguments. "
)
mode_i = _get_indexing_mode(mode)
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
_ind = tuple(
(
indices
if i == pp
else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt)
)
for i in range(x_nd)
)
return _take_multi_index(x, _ind, 0, mode=mode_i)
146 changes: 146 additions & 0 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,5 +1535,151 @@ def test_advanced_integer_indexing_cast_indices():
inds1 = dpt.astype(inds0, "u4")
inds2 = dpt.astype(inds0, "u8")
x = dpt.ones((3, 4, 5, 6), dtype="i4")
# test getitem
with pytest.raises(ValueError):
x[inds0, inds1, inds2, ...]
# test setitem
with pytest.raises(ValueError):
x[inds0, inds1, inds2, ...] = 1


def test_take_along_axis():
get_queue_or_skip()

n0, n1, n2 = 3, 5, 7
x = dpt.reshape(dpt.arange(n0 * n1 * n2), (n0, n1, n2))
ind_dt = dpt.__array_namespace_info__().default_dtypes(
device=x.sycl_device
)["indexing"]
ind0 = dpt.ones((1, n1, n2), dtype=ind_dt)
ind1 = dpt.ones((n0, 1, n2), dtype=ind_dt)
ind2 = dpt.ones((n0, n1, 1), dtype=ind_dt)

y0 = dpt.take_along_axis(x, ind0, axis=0)
assert y0.shape == ind0.shape
y1 = dpt.take_along_axis(x, ind1, axis=1)
assert y1.shape == ind1.shape
y2 = dpt.take_along_axis(x, ind2, axis=2)
assert y2.shape == ind2.shape


def test_take_along_axis_validation():
# type check on the first argument
with pytest.raises(TypeError):
dpt.take_along_axis(tuple(), list())
get_queue_or_skip()
n1, n2 = 2, 5
x = dpt.ones(n1 * n2)
# type check on the second argument
with pytest.raises(TypeError):
dpt.take_along_axis(x, list())
x_dev = x.sycl_device
info_ = dpt.__array_namespace_info__()
def_dtypes = info_.default_dtypes(device=x_dev)
ind_dt = def_dtypes["indexing"]
ind = dpt.zeros(1, dtype=ind_dt)
# axis valudation
with pytest.raises(ValueError):
dpt.take_along_axis(x, ind, axis=1)
# mode validation
with pytest.raises(ValueError):
dpt.take_along_axis(x, ind, axis=0, mode="invalid")
# same array-ranks validation
with pytest.raises(ValueError):
dpt.take_along_axis(dpt.reshape(x, (n1, n2)), ind)
# check compute-follows-data
q2 = dpctl.SyclQueue(x_dev, property="enable_profiling")
ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2)
with pytest.raises(ExecutionPlacementError):
dpt.take_along_axis(x, ind2)


def check__extract_impl_validation(fn):
x = dpt.ones(10)
ind = dpt.ones(10, dtype="?")
with pytest.raises(TypeError):
fn(list(), ind)
with pytest.raises(TypeError):
fn(x, list())
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
ind2 = dpt.ones(10, dtype="?", sycl_queue=q2)
with pytest.raises(ExecutionPlacementError):
fn(x, ind2)
with pytest.raises(ValueError):
fn(x, ind, 1)


def check__nonzero_impl_validation(fn):
with pytest.raises(TypeError):
fn(list())


def check__take_multi_index(fn):
x = dpt.ones(10)
x_dev = x.sycl_device
info_ = dpt.__array_namespace_info__()
def_dtypes = info_.default_dtypes(device=x_dev)
ind_dt = def_dtypes["indexing"]
ind = dpt.arange(10, dtype=ind_dt)
with pytest.raises(TypeError):
fn(list(), tuple(), 1)
with pytest.raises(ValueError):
fn(x, (ind,), 0, mode=2)
with pytest.raises(ValueError):
fn(x, (None,), 1)
with pytest.raises(IndexError):
fn(x, (x,), 1)
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
ind2 = dpt.arange(10, dtype=ind_dt, sycl_queue=q2)
with pytest.raises(ExecutionPlacementError):
fn(x, (ind2,), 0)
m = dpt.ones((10, 10))
ind_1 = dpt.arange(10, dtype="i8")
ind_2 = dpt.arange(10, dtype="u8")
with pytest.raises(ValueError):
fn(m, (ind_1, ind_2), 0)


def check__place_impl_validation(fn):
with pytest.raises(TypeError):
fn(list(), list(), list())
x = dpt.ones(10)
with pytest.raises(TypeError):
fn(x, list(), list())
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
mask2 = dpt.ones(10, dtype="?", sycl_queue=q2)
with pytest.raises(ExecutionPlacementError):
fn(x, mask2, 1)
x2 = dpt.ones((5, 5))
mask2 = dpt.ones((5, 5), dtype="?")
with pytest.raises(ValueError):
fn(x2, mask2, x2, axis=1)


def check__put_multi_index_validation(fn):
with pytest.raises(TypeError):
fn(list(), list(), 0, list())
x = dpt.ones(10)
inds = dpt.arange(10, dtype="i8")
vals = dpt.zeros(10)
# test inds which is not a tuple/list
fn(x, inds, 0, vals)
x2 = dpt.ones((5, 5))
ind1 = dpt.arange(5, dtype="i8")
ind2 = dpt.arange(5, dtype="u8")
with pytest.raises(ValueError):
fn(x2, (ind1, ind2), 0, x2)
with pytest.raises(TypeError):
fn(x2, (ind1, list()), 0, x2)


def test__copy_utils():
import dpctl.tensor._copy_utils as cu

get_queue_or_skip()

check__extract_impl_validation(cu._extract_impl)
check__nonzero_impl_validation(cu._nonzero_impl)
check__take_multi_index(cu._take_multi_index)
check__place_impl_validation(cu._place_impl)
check__put_multi_index_validation(cu._put_multi_index)
19 changes: 16 additions & 3 deletions dpctl/tests/test_usm_ndarray_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,24 @@ def test_argsort_axis0():
x = dpt.reshape(xf, (n, m))
idx = dpt.argsort(x, axis=0)

conseq_idx = dpt.arange(m, dtype=idx.dtype)
s = x[idx, conseq_idx[dpt.newaxis, :]]
s = dpt.take_along_axis(x, idx, axis=0)

assert dpt.all(s[:-1, :] <= s[1:, :])


def test_argsort_axis1():
get_queue_or_skip()

n, m = 200, 30
xf = dpt.arange(n * m, 0, step=-1, dtype="i4")
x = dpt.reshape(xf, (n, m))
idx = dpt.argsort(x, axis=1)

s = dpt.take_along_axis(x, idx, axis=1)

assert dpt.all(s[:, :-1] <= s[:, 1:])


def test_sort_strided():
get_queue_or_skip()

Expand All @@ -199,8 +211,9 @@ def test_argsort_strided():
x_orig = dpt.arange(100, dtype="i4")
x_flipped = dpt.flip(x_orig, axis=0)
idx = dpt.argsort(x_flipped)
s = dpt.take_along_axis(x_flipped, idx, axis=0)

assert dpt.all(x_flipped[idx] == x_orig)
assert dpt.all(s == x_orig)


def test_sort_0d_array():
Expand Down
Loading