Skip to content

Add support of dpnp.ndindex class #2157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions dpnp/dpnp_iface_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"""


import numpy
from dpctl.tensor._numpy_helper import (
normalize_axis_index,
normalize_axis_tuple,
Expand Down Expand Up @@ -151,8 +150,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):

# compute indices for the iteration axes, and append a trailing ellipsis to
# prevent 0d arrays decaying to scalars
# TODO: replace with dpnp.ndindex
inds = numpy.ndindex(inarr_view.shape[:-1])
inds = dpnp.ndindex(inarr_view.shape[:-1])
inds = (ind + (Ellipsis,) for ind in inds)

# invoke the function on the first item
Expand Down
72 changes: 72 additions & 0 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"indices",
"ix_",
"mask_indices",
"ndindex",
"nonzero",
"place",
"put",
Expand Down Expand Up @@ -1057,6 +1058,77 @@ def mask_indices(
return nonzero(a != 0)


# pylint: disable=invalid-name
# pylint: disable=too-few-public-methods
class ndindex:
"""
An N-dimensional iterator object to index arrays.

Given the shape of an array, an :obj:`dpnp.ndindex` instance iterates over
the N-dimensional index of the array. At each iteration a tuple of indices
is returned, the last dimension is iterated over first.

For full documentation refer to :obj:`numpy.ndindex`.

Parameters
----------
shape : ints, or a single tuple of ints
The size of each dimension of the array can be passed as individual
parameters or as the elements of a tuple.

See Also
--------
:obj:`dpnp.ndenumerate` : Multidimensional index iterator.
:obj:`dpnp.flatiter` : Flat iterator object to iterate over arrays.

Examples
--------
>>> import dpnp as np

Dimensions as individual arguments

>>> for index in np.ndindex(3, 2, 1):
... print(index)
(0, 0, 0)
(0, 1, 0)
(1, 0, 0)
(1, 1, 0)
(2, 0, 0)
(2, 1, 0)

Same dimensions - but in a tuple ``(3, 2, 1)``

>>> for index in np.ndindex((3, 2, 1)):
... print(index)
(0, 0, 0)
(0, 1, 0)
(1, 0, 0)
(1, 1, 0)
(2, 0, 0)
(2, 1, 0)

"""

def __init__(self, *shape):
self.ndindex_ = numpy.ndindex(*shape)

def __iter__(self):
return self.ndindex_

def __next__(self):
"""
Standard iterator method, updates the index and returns the index tuple.

Returns
-------
val : tuple of ints
Returns a tuple containing the indices of the current iteration.

"""

return self.ndindex_.__next__()


def nonzero(a):
"""
Return the indices of the elements that are non-zero.
Expand Down
3 changes: 1 addition & 2 deletions dpnp/dpnp_utils/dpnp_utils_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):

# compute indices for the iteration axes, and append a trailing
# ellipsis to prevent 0d arrays decaying to scalars
# TODO: replace with dpnp.ndindex when implemented
inds = numpy.ndindex(view.shape[:-1])
inds = dpnp.ndindex(view.shape[:-1])
inds = (ind + (Ellipsis,) for ind in inds)
for ind in inds:
function(view[ind], pad_width[axis], axis, kwargs)
Expand Down
26 changes: 26 additions & 0 deletions dpnp/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,32 @@ def test_ix_error(self, xp, shape):
assert_raises(ValueError, xp.ix_, xp.ones(shape))


@pytest.mark.parametrize(
"shape", [[1, 2, 3], [(1, 2, 3)], [(3,)], [3], [], [()], [0]]
)
class TestNdindex:
def test_basic(self, shape):
result = dpnp.ndindex(*shape)
expected = numpy.ndindex(*shape)

for x, y in zip(result, expected):
assert x == y

def test_next(self, shape):
dind = dpnp.ndindex(*shape)
nind = numpy.ndindex(*shape)

while True:
try:
ditem = next(dind)
except StopIteration:
assert_raises(StopIteration, next, nind)
break # both reach ends
else:
nitem = next(nind)
assert ditem == nitem


class TestNonzero:
@pytest.mark.parametrize("list_val", [[], [0], [1]])
def test_trivial(self, list_val):
Expand Down
Loading