Skip to content

Commit 99c9507

Browse files
committed
Add tests for __next__() method
1 parent d7119ce commit 99c9507

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

dpnp/tests/test_indexing.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,17 +335,31 @@ def test_ix_error(self, xp, shape):
335335
assert_raises(ValueError, xp.ix_, xp.ones(shape))
336336

337337

338+
@pytest.mark.parametrize(
339+
"shape", [[1, 2, 3], [(1, 2, 3)], [(3,)], [3], [], [()], [0]]
340+
)
338341
class TestNdindex:
339-
@pytest.mark.parametrize(
340-
"shape", [[1, 2, 3], [(1, 2, 3)], [(3,)], [3], [], [()], [0]]
341-
)
342342
def test_basic(self, shape):
343343
result = dpnp.ndindex(*shape)
344344
expected = numpy.ndindex(*shape)
345345

346346
for x, y in zip(result, expected):
347347
assert x == y
348348

349+
def test_next(self, shape):
350+
dind = dpnp.ndindex(*shape)
351+
nind = numpy.ndindex(*shape)
352+
353+
while True:
354+
try:
355+
ditem = next(dind)
356+
except StopIteration:
357+
assert_raises(StopIteration, next, nind)
358+
break # both reach ends
359+
else:
360+
nitem = next(nind)
361+
assert ditem == nitem
362+
349363

350364
class TestNonzero:
351365
@pytest.mark.parametrize("list_val", [[], [0], [1]])

0 commit comments

Comments
 (0)