Skip to content

Commit 6002c73

Browse files
committed
Implement dpnp.take_along_axis through dpctl.tensor
1 parent 543605c commit 6002c73

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,7 +1490,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
14901490
return dpnp.get_result_array(result, out)
14911491

14921492

1493-
def take_along_axis(a, indices, axis):
1493+
def take_along_axis(a, indices, axis, mode="wrap"):
14941494
"""
14951495
Take values from the input array by matching 1d index and data slices.
14961496
@@ -1511,15 +1511,24 @@ def take_along_axis(a, indices, axis):
15111511
Indices to take along each 1d slice of `a`. This must match the
15121512
dimension of the input array, but dimensions ``Ni`` and ``Nj``
15131513
only need to broadcast against `a`.
1514-
axis : int
1514+
axis : {None, int}
15151515
The axis to take 1d slices along. If axis is ``None``, the input
15161516
array is treated as if it had first been flattened to 1d,
15171517
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
1518+
mode : {"wrap", "clip"}, optional
1519+
Specifies how out-of-bounds indices will be handled. Possible values
1520+
are:
1521+
1522+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
1523+
negative indices.
1524+
- ``"clip"``: clips indices to (``0 <= i < n``).
1525+
1526+
Default: ``"wrap"``.
15181527
15191528
Returns
15201529
-------
15211530
out : dpnp.ndarray
1522-
The indexed result.
1531+
The indexed result of the same data type as `a`.
15231532
15241533
See Also
15251534
--------
@@ -1579,12 +1588,21 @@ def take_along_axis(a, indices, axis):
15791588
15801589
"""
15811590

1582-
dpnp.check_supported_arrays_type(a, indices)
1583-
15841591
if axis is None:
1585-
a = a.ravel()
1592+
dpnp.check_supported_arrays_type(indices)
1593+
if indices.ndim != 1:
1594+
raise ValueError(
1595+
"when axis=None, `indices` must have a single dimension."
1596+
)
15861597

1587-
return a[_build_along_axis_index(a, indices, axis)]
1598+
a = dpnp.ravel(a)
1599+
axis = 0
1600+
1601+
usm_a = dpnp.get_usm_ndarray(a)
1602+
usm_ind = dpnp.get_usm_ndarray(indices)
1603+
1604+
usm_res = dpt.take_along_axis(usm_a, usm_ind, axis=axis, mode=mode)
1605+
return dpnp_array._create_from_usm_ndarray(usm_res)
15881606

15891607

15901608
def tril_indices(

0 commit comments

Comments
 (0)