Skip to content

Commit 83a08a4

Browse files
authored
Merge d77a69f into d06277f
2 parents d06277f + d77a69f commit 83a08a4

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
"indices",
6868
"ix_",
6969
"mask_indices",
70+
"ndindex",
7071
"nonzero",
7172
"place",
7273
"put",
@@ -1059,6 +1060,57 @@ def mask_indices(
10591060
return nonzero(a != 0)
10601061

10611062

1063+
# pylint: disable=invalid-name
1064+
ndindex = numpy.ndindex
1065+
ndindex.__doc__ = """
1066+
An N-dimensional iterator object to index arrays.
1067+
1068+
Given the shape of an array, an `ndindex` instance iterates over the
1069+
N-dimensional index of the array. At each iteration a tuple of indices is
1070+
returned, the last dimension is iterated over first.
1071+
1072+
For full documentation refer to :obj:`numpy.ndindex`.
1073+
1074+
Parameters
1075+
----------
1076+
shape : ints, or a single tuple of ints
1077+
The size of each dimension of the array can be passed as individual
1078+
parameters or as the elements of a tuple.
1079+
1080+
See Also
1081+
--------
1082+
:obj:`dpnp.ndenumerate` : Multidimensional index iterator.
1083+
:obj:`dpnp.flatiter` : Flat iterator object to iterate over arrays.
1084+
1085+
Examples
1086+
--------
1087+
>>> import dpnp as np
1088+
1089+
Dimensions as individual arguments
1090+
1091+
>>> for index in np.ndindex(3, 2, 1):
1092+
... print(index)
1093+
(0, 0, 0)
1094+
(0, 1, 0)
1095+
(1, 0, 0)
1096+
(1, 1, 0)
1097+
(2, 0, 0)
1098+
(2, 1, 0)
1099+
1100+
Same dimensions - but in a tuple ``(3, 2, 1)``
1101+
1102+
>>> for index in np.ndindex((3, 2, 1)):
1103+
... print(index)
1104+
(0, 0, 0)
1105+
(0, 1, 0)
1106+
(1, 0, 0)
1107+
(1, 1, 0)
1108+
(2, 0, 0)
1109+
(2, 1, 0)
1110+
1111+
"""
1112+
1113+
10621114
def nonzero(a):
10631115
"""
10641116
Return the indices of the elements that are non-zero.

dpnp/dpnp_utils/dpnp_utils_pad.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
642642

643643
# compute indices for the iteration axes, and append a trailing
644644
# ellipsis to prevent 0d arrays decaying to scalars
645-
# TODO: replace with dpnp.ndindex when implemented
646-
inds = numpy.ndindex(view.shape[:-1])
645+
inds = dpnp.ndindex(view.shape[:-1])
647646
inds = (ind + (Ellipsis,) for ind in inds)
648647
for ind in inds:
649648
function(view[ind], pad_width[axis], axis, kwargs)

tests/test_indexing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,18 @@ def test_ix_error(self, xp, shape):
335335
assert_raises(ValueError, xp.ix_, xp.ones(shape))
336336

337337

338+
class TestNdindex:
339+
@pytest.mark.parametrize(
340+
"shape", [[1, 2, 3], [(1, 2, 3)], [(3,)], [3], [], [()], [0]]
341+
)
342+
def test_basic(self, shape):
343+
result = dpnp.ndindex(*shape)
344+
expected = numpy.ndindex(*shape)
345+
346+
for x, y in zip(result, expected):
347+
assert x == y
348+
349+
338350
class TestNonzero:
339351
@pytest.mark.parametrize("list_val", [[], [0], [1]])
340352
def test_trivial(self, list_val):

0 commit comments

Comments
 (0)