@@ -1490,7 +1490,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
1490
1490
return dpnp .get_result_array (result , out )
1491
1491
1492
1492
1493
- def take_along_axis (a , indices , axis ):
1493
+ def take_along_axis (a , indices , axis , mode = "wrap" ):
1494
1494
"""
1495
1495
Take values from the input array by matching 1d index and data slices.
1496
1496
@@ -1511,15 +1511,24 @@ def take_along_axis(a, indices, axis):
1511
1511
Indices to take along each 1d slice of `a`. This must match the
1512
1512
dimension of the input array, but dimensions ``Ni`` and ``Nj``
1513
1513
only need to broadcast against `a`.
1514
- axis : int
1514
+ axis : {None, int}
1515
1515
The axis to take 1d slices along. If axis is ``None``, the input
1516
1516
array is treated as if it had first been flattened to 1d,
1517
1517
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
1518
+ mode : {"wrap", "clip"}, optional
1519
+ Specifies how out-of-bounds indices will be handled. Possible values
1520
+ are:
1521
+
1522
+ - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
1523
+ negative indices.
1524
+ - ``"clip"``: clips indices to (``0 <= i < n``).
1525
+
1526
+ Default: ``"wrap"``.
1518
1527
1519
1528
Returns
1520
1529
-------
1521
1530
out : dpnp.ndarray
1522
- The indexed result.
1531
+ The indexed result of the same data type as `a` .
1523
1532
1524
1533
See Also
1525
1534
--------
@@ -1579,12 +1588,21 @@ def take_along_axis(a, indices, axis):
1579
1588
1580
1589
"""
1581
1590
1582
- dpnp .check_supported_arrays_type (a , indices )
1583
-
1584
1591
if axis is None :
1585
- a = a .ravel ()
1592
+ dpnp .check_supported_arrays_type (indices )
1593
+ if indices .ndim != 1 :
1594
+ raise ValueError (
1595
+ "when axis=None, `indices` must have a single dimension."
1596
+ )
1586
1597
1587
- return a [_build_along_axis_index (a , indices , axis )]
1598
+ a = dpnp .ravel (a )
1599
+ axis = 0
1600
+
1601
+ usm_a = dpnp .get_usm_ndarray (a )
1602
+ usm_ind = dpnp .get_usm_ndarray (indices )
1603
+
1604
+ usm_res = dpt .take_along_axis (usm_a , usm_ind , axis = axis , mode = mode )
1605
+ return dpnp_array ._create_from_usm_ndarray (usm_res )
1588
1606
1589
1607
1590
1608
def tril_indices (
0 commit comments