Skip to content

Commit 4ad9e94

Browse files
committed
Fix test_roll with bespoke axis iterator
1 parent a5fd48f commit 4ad9e94

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

array_api_tests/meta/test_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .. import shape_helpers as sh
44
from ..test_creation_functions import frange
5+
from ..test_manipulation_functions import roll_ndindex
56
from ..test_signatures import extension_module
67

78

@@ -68,3 +69,16 @@ def test_axis_ndindex(shape, axis, expected):
6869
)
6970
def test_axes_ndindex(shape, axes, expected):
7071
assert list(sh.axes_ndindex(shape, axes)) == expected
72+
73+
74+
@pytest.mark.parametrize(
75+
"shape, shifts, axes, expected",
76+
[
77+
((1, 1), (0,), (0,), [(0, 0)]),
78+
((2, 1), (1, 1), (0, 1), [(1, 0), (0, 0)]),
79+
((2, 2), (1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]),
80+
((2, 2), (-1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]),
81+
],
82+
)
83+
def test_roll_ndindex(shape, shifts, axes, expected):
84+
assert list(roll_ndindex(shape, shifts, axes)) == expected

array_api_tests/test_manipulation_functions.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from collections import deque
3-
from typing import Iterable, Union
3+
from typing import Iterable, Iterator, Tuple, Union
44

55
import pytest
66
from hypothesis import assume, given
@@ -33,8 +33,10 @@ def assert_array_ndindex(
3333
x_indices: Iterable[Union[int, Shape]],
3434
out: Array,
3535
out_indices: Iterable[Union[int, Shape]],
36+
/,
37+
**kw,
3638
):
37-
msg_suffix = f" [{func_name}()]\n {x=}\n{out=}"
39+
msg_suffix = f" [{func_name}({ph.fmt_kw(kw)})]\n {x=}\n{out=}"
3840
for x_idx, out_idx in zip(x_indices, out_indices):
3941
msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}"
4042
msg += msg_suffix
@@ -266,7 +268,15 @@ def test_reshape(x, data):
266268
assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape))
267269

268270

269-
@pytest.mark.skip(reason="faulty test logic") # TODO
271+
def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator[Shape]:
272+
assert len(shifts) == len(axes) # sanity check
273+
all_shifts = [0 for _ in shape]
274+
for s, a in zip(shifts, axes):
275+
all_shifts[a] = s
276+
for idx in sh.ndindex(shape):
277+
yield tuple((i + sh) % si for i, sh, si in zip(idx, all_shifts, shape))
278+
279+
270280
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data())
271281
def test_roll(x, data):
272282
shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE)
@@ -287,6 +297,8 @@ def test_roll(x, data):
287297

288298
out = xp.roll(x, shift, **kw)
289299

300+
kw = {"shift": shift, **kw} # for error messages
301+
290302
ph.assert_dtype("roll", x.dtype, out.dtype)
291303

292304
ph.assert_result_shape("roll", (x.shape,), out.shape)
@@ -296,18 +308,12 @@ def test_roll(x, data):
296308
indices = list(sh.ndindex(x.shape))
297309
shifted_indices = deque(indices)
298310
shifted_indices.rotate(-shift)
299-
assert_array_ndindex("roll", x, indices, out, shifted_indices)
311+
assert_array_ndindex("roll", x, indices, out, shifted_indices, **kw)
300312
else:
301-
_shift = (shift,) if isinstance(shift, int) else shift
313+
shifts = (shift,) if isinstance(shift, int) else shift
302314
axes = sh.normalise_axis(kw["axis"], x.ndim)
303-
all_indices = list(sh.ndindex(x.shape))
304-
for s, a in zip(_shift, axes):
305-
side = x.shape[a]
306-
for i in range(side):
307-
indices = [idx for idx in all_indices if idx[a] == i]
308-
shifted_indices = deque(indices)
309-
shifted_indices.rotate(-s)
310-
assert_array_ndindex("roll", x, indices, out, shifted_indices)
315+
shifted_indices = roll_ndindex(x.shape, shifts, axes)
316+
assert_array_ndindex("roll", x, sh.ndindex(x.shape), out, shifted_indices, **kw)
311317

312318

313319
@given(

0 commit comments

Comments
 (0)