We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 92ca438 commit c8479ddCopy full SHA for c8479dd
dpctl/tests/test_usm_ndarray_indexing.py
@@ -1537,3 +1537,23 @@ def test_advanced_integer_indexing_cast_indices():
1537
x = dpt.ones((3, 4, 5, 6), dtype="i4")
1538
with pytest.raises(ValueError):
1539
x[inds0, inds1, inds2, ...]
1540
+
1541
1542
+def test_take_along_axis():
1543
+ get_queue_or_skip()
1544
1545
+ n0, n1, n2 = 3, 5, 7
1546
+ x = dpt.reshape(dpt.arange(n0 * n1 * n2), (n0, n1, n2))
1547
+ ind_dt = dpt.__array_namespace_info__().default_dtypes(
1548
+ device=x.sycl_device
1549
+ )["indexing"]
1550
+ ind0 = dpt.ones((1, n1, n2), dtype=ind_dt)
1551
+ ind1 = dpt.ones((n0, 1, n2), dtype=ind_dt)
1552
+ ind2 = dpt.ones((n0, n1, 1), dtype=ind_dt)
1553
1554
+ y0 = dpt.take_along_axis(x, ind0, axis=0)
1555
+ assert y0.shape == ind0.shape
1556
+ y1 = dpt.take_along_axis(x, ind1, axis=1)
1557
+ assert y1.shape == ind1.shape
1558
+ y2 = dpt.take_along_axis(x, ind2, axis=2)
1559
+ assert y2.shape == ind2.shape
0 commit comments