Skip to content

Commit 5fda819

Browse files
authored
Leverage on dpctl.tensor implementation in dpnp.take_along_axis (#1969)
* Implement dpnp.take_along_axis through dpctl.tensor * Added more tests to cover new logic * Increase test coverage * Fix type in docstring of dpnp.put()
1 parent 4a23239 commit 5fda819

File tree

2 files changed

+86
-47
lines changed

2 files changed

+86
-47
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
11741174
v : {scalar, array_like}
11751175
Values to be put into `a`. Must be broadcastable to the result shape
11761176
``a.shape[:axis] + ind.shape + a.shape[axis+1:]``.
1177-
axis {None, int}, optional
1177+
axis : {None, int}, optional
11781178
The axis along which the values will be placed. If `a` is 1-D array,
11791179
this argument is optional.
11801180
Default: ``None``.
@@ -1502,7 +1502,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
15021502
return dpnp.get_result_array(result, out)
15031503

15041504

1505-
def take_along_axis(a, indices, axis):
1505+
def take_along_axis(a, indices, axis, mode="wrap"):
15061506
"""
15071507
Take values from the input array by matching 1d index and data slices.
15081508
@@ -1523,15 +1523,24 @@ def take_along_axis(a, indices, axis):
15231523
Indices to take along each 1d slice of `a`. This must match the
15241524
dimension of the input array, but dimensions ``Ni`` and ``Nj``
15251525
only need to broadcast against `a`.
1526-
axis : int
1526+
axis : {None, int}
15271527
The axis to take 1d slices along. If axis is ``None``, the input
15281528
array is treated as if it had first been flattened to 1d,
15291529
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
1530+
mode : {"wrap", "clip"}, optional
1531+
Specifies how out-of-bounds indices will be handled. Possible values
1532+
are:
1533+
1534+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
1535+
negative indices.
1536+
- ``"clip"``: clips indices to (``0 <= i < n``).
1537+
1538+
Default: ``"wrap"``.
15301539
15311540
Returns
15321541
-------
15331542
out : dpnp.ndarray
1534-
The indexed result.
1543+
The indexed result of the same data type as `a`.
15351544
15361545
See Also
15371546
--------
@@ -1591,12 +1600,21 @@ def take_along_axis(a, indices, axis):
15911600
15921601
"""
15931602

1594-
dpnp.check_supported_arrays_type(a, indices)
1595-
15961603
if axis is None:
1597-
a = a.ravel()
1604+
dpnp.check_supported_arrays_type(indices)
1605+
if indices.ndim != 1:
1606+
raise ValueError(
1607+
"when axis=None, `indices` must have a single dimension."
1608+
)
15981609

1599-
return a[_build_along_axis_index(a, indices, axis)]
1610+
a = dpnp.ravel(a)
1611+
axis = 0
1612+
1613+
usm_a = dpnp.get_usm_ndarray(a)
1614+
usm_ind = dpnp.get_usm_ndarray(indices)
1615+
1616+
usm_res = dpt.take_along_axis(usm_a, usm_ind, axis=axis, mode=mode)
1617+
return dpnp_array._create_from_usm_ndarray(usm_res)
16001618

16011619

16021620
def tril_indices(

tests/test_indexing.py

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,13 @@ def test_values(self, arr_dt, idx_dt, ndim, values):
544544
dpnp.put_along_axis(dp_a, dp_ai, values, axis)
545545
assert_array_equal(np_a, dp_a)
546546

547+
@pytest.mark.parametrize("xp", [numpy, dpnp])
548+
@pytest.mark.parametrize("dt", [bool, numpy.float32])
549+
def test_invalid_indices_dtype(self, xp, dt):
550+
a = xp.ones((10, 10))
551+
ind = xp.ones(10, dtype=dt)
552+
assert_raises(IndexError, xp.put_along_axis, a, ind, 7, axis=1)
553+
547554
@pytest.mark.parametrize("arr_dt", get_all_dtypes())
548555
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
549556
def test_broadcast(self, arr_dt, idx_dt):
@@ -673,66 +680,80 @@ def test_argequivalent(self, func, argfunc, kwargs):
673680
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
674681
@pytest.mark.parametrize("ndim", list(range(1, 4)))
675682
def test_multi_dimensions(self, arr_dt, idx_dt, ndim):
676-
np_a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
677-
np_ai = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
683+
a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
684+
ind = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
678685
(1,) * (ndim - 1) + (4,)
679686
)
680-
681-
dp_a = dpnp.array(np_a, dtype=arr_dt)
682-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
687+
ia, iind = dpnp.array(a), dpnp.array(ind)
683688

684689
for axis in range(ndim):
685-
expected = numpy.take_along_axis(np_a, np_ai, axis)
686-
result = dpnp.take_along_axis(dp_a, dp_ai, axis)
690+
result = dpnp.take_along_axis(ia, iind, axis)
691+
expected = numpy.take_along_axis(a, ind, axis)
687692
assert_array_equal(expected, result)
688693

689694
@pytest.mark.parametrize("xp", [numpy, dpnp])
690-
def test_invalid(self, xp):
695+
def test_not_enough_indices(self, xp):
691696
a = xp.ones((10, 10))
692-
ai = xp.ones((10, 2), dtype=xp.intp)
693-
694-
# not enough indices
695697
assert_raises(ValueError, xp.take_along_axis, a, xp.array(1), axis=1)
696698

697-
# bool arrays not allowed
698-
assert_raises(
699-
IndexError, xp.take_along_axis, a, ai.astype(bool), axis=1
700-
)
699+
@pytest.mark.parametrize("xp", [numpy, dpnp])
700+
@pytest.mark.parametrize("dt", [bool, numpy.float32])
701+
def test_invalid_indices_dtype(self, xp, dt):
702+
a = xp.ones((10, 10))
703+
ind = xp.ones((10, 2), dtype=dt)
704+
assert_raises(IndexError, xp.take_along_axis, a, ind, axis=1)
701705

702-
# float arrays not allowed
703-
assert_raises(
704-
IndexError, xp.take_along_axis, a, ai.astype(numpy.float32), axis=1
705-
)
706+
@pytest.mark.parametrize("xp", [numpy, dpnp])
707+
def test_invalid_axis(self, xp):
708+
a = xp.ones((10, 10))
709+
ind = xp.ones((10, 2), dtype=xp.intp)
710+
assert_raises(AxisError, xp.take_along_axis, a, ind, axis=10)
706711

707-
# invalid axis
708-
assert_raises(AxisError, xp.take_along_axis, a, ai, axis=10)
712+
@pytest.mark.parametrize("xp", [numpy, dpnp])
713+
def test_indices_ndim_axis_none(self, xp):
714+
a = xp.ones((10, 10))
715+
ind = xp.ones((10, 2), dtype=xp.intp)
716+
assert_raises(ValueError, xp.take_along_axis, a, ind, axis=None)
709717

710-
@pytest.mark.parametrize("arr_dt", get_all_dtypes())
718+
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
711719
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
712-
def test_empty(self, arr_dt, idx_dt):
713-
np_a = numpy.ones((3, 4, 5), dtype=arr_dt)
714-
np_ai = numpy.ones((3, 0, 5), dtype=idx_dt)
715-
716-
dp_a = dpnp.array(np_a, dtype=arr_dt)
717-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
720+
def test_empty(self, a_dt, idx_dt):
721+
a = numpy.ones((3, 4, 5), dtype=a_dt)
722+
ind = numpy.ones((3, 0, 5), dtype=idx_dt)
723+
ia, iind = dpnp.array(a), dpnp.array(ind)
718724

719-
expected = numpy.take_along_axis(np_a, np_ai, axis=1)
720-
result = dpnp.take_along_axis(dp_a, dp_ai, axis=1)
725+
result = dpnp.take_along_axis(ia, iind, axis=1)
726+
expected = numpy.take_along_axis(a, ind, axis=1)
721727
assert_array_equal(expected, result)
722728

723-
@pytest.mark.parametrize("arr_dt", get_all_dtypes())
729+
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
724730
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
725-
def test_broadcast(self, arr_dt, idx_dt):
726-
np_a = numpy.ones((3, 4, 1), dtype=arr_dt)
727-
np_ai = numpy.ones((1, 2, 5), dtype=idx_dt)
728-
729-
dp_a = dpnp.array(np_a, dtype=arr_dt)
730-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
731+
def test_broadcast(self, a_dt, idx_dt):
732+
a = numpy.ones((3, 4, 1), dtype=a_dt)
733+
ind = numpy.ones((1, 2, 5), dtype=idx_dt)
734+
ia, iind = dpnp.array(a), dpnp.array(ind)
731735

732-
expected = numpy.take_along_axis(np_a, np_ai, axis=1)
733-
result = dpnp.take_along_axis(dp_a, dp_ai, axis=1)
736+
result = dpnp.take_along_axis(ia, iind, axis=1)
737+
expected = numpy.take_along_axis(a, ind, axis=1)
734738
assert_array_equal(expected, result)
735739

740+
def test_mode_wrap(self):
741+
a = numpy.array([-2, -1, 0, 1, 2])
742+
ind = numpy.array([-2, 2, -5, 4])
743+
ia, iind = dpnp.array(a), dpnp.array(ind)
744+
745+
result = dpnp.take_along_axis(ia, iind, axis=0, mode="wrap")
746+
expected = numpy.take_along_axis(a, ind, axis=0)
747+
assert_array_equal(result, expected)
748+
749+
def test_mode_clip(self):
750+
a = dpnp.array([-2, -1, 0, 1, 2])
751+
ind = dpnp.array([-2, 2, -5, 4])
752+
753+
# numpy does not support keyword `mode`
754+
result = dpnp.take_along_axis(a, ind, axis=0, mode="clip")
755+
assert (result == dpnp.array([-2, 0, -2, 2])).all()
756+
736757

737758
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
738759
def test_choose():

0 commit comments

Comments
 (0)