Skip to content

Commit c8479dd

Browse files
Basic test for take_along_axis added
1 parent 92ca438 commit c8479dd

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,3 +1537,23 @@ def test_advanced_integer_indexing_cast_indices():
15371537
x = dpt.ones((3, 4, 5, 6), dtype="i4")
15381538
with pytest.raises(ValueError):
15391539
x[inds0, inds1, inds2, ...]
1540+
1541+
1542+
def test_take_along_axis():
1543+
get_queue_or_skip()
1544+
1545+
n0, n1, n2 = 3, 5, 7
1546+
x = dpt.reshape(dpt.arange(n0 * n1 * n2), (n0, n1, n2))
1547+
ind_dt = dpt.__array_namespace_info__().default_dtypes(
1548+
device=x.sycl_device
1549+
)["indexing"]
1550+
ind0 = dpt.ones((1, n1, n2), dtype=ind_dt)
1551+
ind1 = dpt.ones((n0, 1, n2), dtype=ind_dt)
1552+
ind2 = dpt.ones((n0, n1, 1), dtype=ind_dt)
1553+
1554+
y0 = dpt.take_along_axis(x, ind0, axis=0)
1555+
assert y0.shape == ind0.shape
1556+
y1 = dpt.take_along_axis(x, ind1, axis=1)
1557+
assert y1.shape == ind1.shape
1558+
y2 = dpt.take_along_axis(x, ind2, axis=2)
1559+
assert y2.shape == ind2.shape

0 commit comments

Comments
 (0)