Skip to content

Commit 529d31f

Browse files
committed
Test than nonzero on zero-dimensional array raises an exception.
This is required by the spec (see https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.nonzero.html)
1 parent d58fb6b commit 529d31f

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,17 @@ def test_argmin(x, data):
8787
ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
8888

8989

90+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=0, max_dims=0)))
91+
def test_nonzero_zerodim_error(x):
92+
with pytest.raises(Exception):
93+
xp.nonzero(x)
94+
95+
9096
@pytest.mark.data_dependent_shapes
91-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
97+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1, min_side=1)))
9298
def test_nonzero(x):
9399
out = xp.nonzero(x)
94-
if x.ndim == 0:
95-
assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays"
96-
else:
97-
assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
100+
assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
98101
out_size = math.prod(out[0].shape)
99102
for i in range(len(out)):
100103
assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1"

0 commit comments

Comments
 (0)