Skip to content

Commit 43e0148

Browse files
committed
Add shape testing for moveaxis
1 parent 63a64ab commit 43e0148

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,21 @@ def test_moveaxis(x, data):
172172
out = xp.moveaxis(x, source, destination)
173173

174174
ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype)
175-
# TODO: shape and values testing
175+
176+
# Shape testing
177+
_source = sh.normalize_axis(source, x.ndim)
178+
_destination = sh.normalize_axis(destination, x.ndim)
179+
180+
new_axes = [n for n in range(x.ndim) if n not in _source]
181+
182+
for dest, src in sorted(zip(_destination, _source)):
183+
new_axes.insert(dest, src)
184+
185+
expected_shape = tuple(x.shape[i] for i in new_axes)
186+
187+
ph.assert_result_shape("moveaxis", in_shapes=[x.shape],
188+
out_shape=out.shape, expected=expected_shape,
189+
kw={"source": source, "destination": destination})
176190

177191
@pytest.mark.unvectorized
178192
@given(

0 commit comments

Comments
 (0)