Skip to content

Commit ffd3829

Browse files
authored
Add support of dpnp.ndindex class (#2157)
* Leverage dpnp.ndindex on numpy class * Remove TODO in the code with dpnp.ndindex * Add tests to dpnp.ndindex * Inherit ndindex from numpy class * Instantiate instance of numpy class * Remove TODO in dpnp/dpnp_iface_functional.py * Add tests for __next__() method
1 parent 81e7e29 commit ffd3829

File tree

4 files changed

+100
-5
lines changed

4 files changed

+100
-5
lines changed

dpnp/dpnp_iface_functional.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
"""
3838

3939

40-
import numpy
4140
from dpctl.tensor._numpy_helper import (
4241
normalize_axis_index,
4342
normalize_axis_tuple,
@@ -151,8 +150,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
151150

152151
# compute indices for the iteration axes, and append a trailing ellipsis to
153152
# prevent 0d arrays decaying to scalars
154-
# TODO: replace with dpnp.ndindex
155-
inds = numpy.ndindex(inarr_view.shape[:-1])
153+
inds = dpnp.ndindex(inarr_view.shape[:-1])
156154
inds = (ind + (Ellipsis,) for ind in inds)
157155

158156
# invoke the function on the first item

dpnp/dpnp_iface_indexing.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"indices",
6565
"ix_",
6666
"mask_indices",
67+
"ndindex",
6768
"nonzero",
6869
"place",
6970
"put",
@@ -1057,6 +1058,77 @@ def mask_indices(
10571058
return nonzero(a != 0)
10581059

10591060

1061+
# pylint: disable=invalid-name
1062+
# pylint: disable=too-few-public-methods
1063+
class ndindex:
1064+
"""
1065+
An N-dimensional iterator object to index arrays.
1066+
1067+
Given the shape of an array, an :obj:`dpnp.ndindex` instance iterates over
1068+
the N-dimensional index of the array. At each iteration a tuple of indices
1069+
is returned, the last dimension is iterated over first.
1070+
1071+
For full documentation refer to :obj:`numpy.ndindex`.
1072+
1073+
Parameters
1074+
----------
1075+
shape : ints, or a single tuple of ints
1076+
The size of each dimension of the array can be passed as individual
1077+
parameters or as the elements of a tuple.
1078+
1079+
See Also
1080+
--------
1081+
:obj:`dpnp.ndenumerate` : Multidimensional index iterator.
1082+
:obj:`dpnp.flatiter` : Flat iterator object to iterate over arrays.
1083+
1084+
Examples
1085+
--------
1086+
>>> import dpnp as np
1087+
1088+
Dimensions as individual arguments
1089+
1090+
>>> for index in np.ndindex(3, 2, 1):
1091+
... print(index)
1092+
(0, 0, 0)
1093+
(0, 1, 0)
1094+
(1, 0, 0)
1095+
(1, 1, 0)
1096+
(2, 0, 0)
1097+
(2, 1, 0)
1098+
1099+
Same dimensions - but in a tuple ``(3, 2, 1)``
1100+
1101+
>>> for index in np.ndindex((3, 2, 1)):
1102+
... print(index)
1103+
(0, 0, 0)
1104+
(0, 1, 0)
1105+
(1, 0, 0)
1106+
(1, 1, 0)
1107+
(2, 0, 0)
1108+
(2, 1, 0)
1109+
1110+
"""
1111+
1112+
def __init__(self, *shape):
1113+
self.ndindex_ = numpy.ndindex(*shape)
1114+
1115+
def __iter__(self):
1116+
return self.ndindex_
1117+
1118+
def __next__(self):
1119+
"""
1120+
Standard iterator method, updates the index and returns the index tuple.
1121+
1122+
Returns
1123+
-------
1124+
val : tuple of ints
1125+
Returns a tuple containing the indices of the current iteration.
1126+
1127+
"""
1128+
1129+
return self.ndindex_.__next__()
1130+
1131+
10601132
def nonzero(a):
10611133
"""
10621134
Return the indices of the elements that are non-zero.

dpnp/dpnp_utils/dpnp_utils_pad.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
642642

643643
# compute indices for the iteration axes, and append a trailing
644644
# ellipsis to prevent 0d arrays decaying to scalars
645-
# TODO: replace with dpnp.ndindex when implemented
646-
inds = numpy.ndindex(view.shape[:-1])
645+
inds = dpnp.ndindex(view.shape[:-1])
647646
inds = (ind + (Ellipsis,) for ind in inds)
648647
for ind in inds:
649648
function(view[ind], pad_width[axis], axis, kwargs)

dpnp/tests/test_indexing.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,32 @@ def test_ix_error(self, xp, shape):
335335
assert_raises(ValueError, xp.ix_, xp.ones(shape))
336336

337337

338+
@pytest.mark.parametrize(
339+
"shape", [[1, 2, 3], [(1, 2, 3)], [(3,)], [3], [], [()], [0]]
340+
)
341+
class TestNdindex:
342+
def test_basic(self, shape):
343+
result = dpnp.ndindex(*shape)
344+
expected = numpy.ndindex(*shape)
345+
346+
for x, y in zip(result, expected):
347+
assert x == y
348+
349+
def test_next(self, shape):
350+
dind = dpnp.ndindex(*shape)
351+
nind = numpy.ndindex(*shape)
352+
353+
while True:
354+
try:
355+
ditem = next(dind)
356+
except StopIteration:
357+
assert_raises(StopIteration, next, nind)
358+
break # both reach ends
359+
else:
360+
nitem = next(nind)
361+
assert ditem == nitem
362+
363+
338364
class TestNonzero:
339365
@pytest.mark.parametrize("list_val", [[], [0], [1]])
340366
def test_trivial(self, list_val):

0 commit comments

Comments
 (0)