Skip to content

Commit c139d2c

Browse files
Add take_along_axis arg validation test
1 parent c8479dd commit c139d2c

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,3 +1557,20 @@ def test_take_along_axis():
15571557
assert y1.shape == ind1.shape
15581558
y2 = dpt.take_along_axis(x, ind2, axis=2)
15591559
assert y2.shape == ind2.shape
1560+
1561+
1562+
def test_take_along_axis_validation():
1563+
with pytest.raises(TypeError):
1564+
dpt.take_along_axis(tuple(), list())
1565+
get_queue_or_skip()
1566+
x = dpt.ones(10)
1567+
with pytest.raises(TypeError):
1568+
dpt.take_along_axis(x, list())
1569+
ind_dt = dpt.__array_namespace_info__().default_dtypes(
1570+
device=x.sycl_device
1571+
)["indexing"]
1572+
ind = dpt.zeros(1, dtype=ind_dt)
1573+
with pytest.raises(ValueError):
1574+
dpt.take_along_axis(x, ind, axis=1)
1575+
with pytest.raises(ValueError):
1576+
dpt.take_along_axis(x, ind, axis=0, mode="invalid")

0 commit comments

Comments
 (0)