Skip to content

Commit 51f10b8

Browse files
authored
Merge 55c59c2 into cb31797
2 parents cb31797 + 55c59c2 commit 51f10b8

File tree

2 files changed

+86
-47
lines changed

2 files changed

+86
-47
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
11621162
v : {scalar, array_like}
11631163
Values to be put into `a`. Must be broadcastable to the result shape
11641164
``a.shape[:axis] + ind.shape + a.shape[axis+1:]``.
1165-
axis {None, int}, optional
1165+
axis : {None, int}, optional
11661166
The axis along which the values will be placed. If `a` is 1-D array,
11671167
this argument is optional.
11681168
Default: ``None``.
@@ -1490,7 +1490,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
14901490
return dpnp.get_result_array(result, out)
14911491

14921492

1493-
def take_along_axis(a, indices, axis):
1493+
def take_along_axis(a, indices, axis, mode="wrap"):
14941494
"""
14951495
Take values from the input array by matching 1d index and data slices.
14961496
@@ -1511,15 +1511,24 @@ def take_along_axis(a, indices, axis):
15111511
Indices to take along each 1d slice of `a`. This must match the
15121512
dimension of the input array, but dimensions ``Ni`` and ``Nj``
15131513
only need to broadcast against `a`.
1514-
axis : int
1514+
axis : {None, int}
15151515
The axis to take 1d slices along. If axis is ``None``, the input
15161516
array is treated as if it had first been flattened to 1d,
15171517
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
1518+
mode : {"wrap", "clip"}, optional
1519+
Specifies how out-of-bounds indices will be handled. Possible values
1520+
are:
1521+
1522+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
1523+
negative indices.
1524+
- ``"clip"``: clips indices to (``0 <= i < n``).
1525+
1526+
Default: ``"wrap"``.
15181527
15191528
Returns
15201529
-------
15211530
out : dpnp.ndarray
1522-
The indexed result.
1531+
The indexed result of the same data type as `a`.
15231532
15241533
See Also
15251534
--------
@@ -1579,12 +1588,21 @@ def take_along_axis(a, indices, axis):
15791588
15801589
"""
15811590

1582-
dpnp.check_supported_arrays_type(a, indices)
1583-
15841591
if axis is None:
1585-
a = a.ravel()
1592+
dpnp.check_supported_arrays_type(indices)
1593+
if indices.ndim != 1:
1594+
raise ValueError(
1595+
"when axis=None, `indices` must have a single dimension."
1596+
)
15861597

1587-
return a[_build_along_axis_index(a, indices, axis)]
1598+
a = dpnp.ravel(a)
1599+
axis = 0
1600+
1601+
usm_a = dpnp.get_usm_ndarray(a)
1602+
usm_ind = dpnp.get_usm_ndarray(indices)
1603+
1604+
usm_res = dpt.take_along_axis(usm_a, usm_ind, axis=axis, mode=mode)
1605+
return dpnp_array._create_from_usm_ndarray(usm_res)
15881606

15891607

15901608
def tril_indices(

tests/test_indexing.py

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,13 @@ def test_values(self, arr_dt, idx_dt, ndim, values):
544544
dpnp.put_along_axis(dp_a, dp_ai, values, axis)
545545
assert_array_equal(np_a, dp_a)
546546

547+
@pytest.mark.parametrize("xp", [numpy, dpnp])
548+
@pytest.mark.parametrize("dt", [bool, numpy.float32])
549+
def test_invalid_indices_dtype(self, xp, dt):
550+
a = xp.ones((10, 10))
551+
ind = xp.ones(10, dtype=dt)
552+
assert_raises(IndexError, xp.put_along_axis, a, ind, 7, axis=1)
553+
547554
@pytest.mark.parametrize("arr_dt", get_all_dtypes())
548555
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
549556
def test_broadcast(self, arr_dt, idx_dt):
@@ -673,66 +680,80 @@ def test_argequivalent(self, func, argfunc, kwargs):
673680
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
674681
@pytest.mark.parametrize("ndim", list(range(1, 4)))
675682
def test_multi_dimensions(self, arr_dt, idx_dt, ndim):
676-
np_a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
677-
np_ai = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
683+
a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
684+
ind = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
678685
(1,) * (ndim - 1) + (4,)
679686
)
680-
681-
dp_a = dpnp.array(np_a, dtype=arr_dt)
682-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
687+
ia, iind = dpnp.array(a), dpnp.array(ind)
683688

684689
for axis in range(ndim):
685-
expected = numpy.take_along_axis(np_a, np_ai, axis)
686-
result = dpnp.take_along_axis(dp_a, dp_ai, axis)
690+
result = dpnp.take_along_axis(ia, iind, axis)
691+
expected = numpy.take_along_axis(a, ind, axis)
687692
assert_array_equal(expected, result)
688693

689694
@pytest.mark.parametrize("xp", [numpy, dpnp])
690-
def test_invalid(self, xp):
695+
def test_not_enough_indices(self, xp):
691696
a = xp.ones((10, 10))
692-
ai = xp.ones((10, 2), dtype=xp.intp)
693-
694-
# not enough indices
695697
assert_raises(ValueError, xp.take_along_axis, a, xp.array(1), axis=1)
696698

697-
# bool arrays not allowed
698-
assert_raises(
699-
IndexError, xp.take_along_axis, a, ai.astype(bool), axis=1
700-
)
699+
@pytest.mark.parametrize("xp", [numpy, dpnp])
700+
@pytest.mark.parametrize("dt", [bool, numpy.float32])
701+
def test_invalid_indices_dtype(self, xp, dt):
702+
a = xp.ones((10, 10))
703+
ind = xp.ones((10, 2), dtype=dt)
704+
assert_raises(IndexError, xp.take_along_axis, a, ind, axis=1)
701705

702-
# float arrays not allowed
703-
assert_raises(
704-
IndexError, xp.take_along_axis, a, ai.astype(numpy.float32), axis=1
705-
)
706+
@pytest.mark.parametrize("xp", [numpy, dpnp])
707+
def test_invalid_axis(self, xp):
708+
a = xp.ones((10, 10))
709+
ind = xp.ones((10, 2), dtype=xp.intp)
710+
assert_raises(AxisError, xp.take_along_axis, a, ind, axis=10)
706711

707-
# invalid axis
708-
assert_raises(AxisError, xp.take_along_axis, a, ai, axis=10)
712+
@pytest.mark.parametrize("xp", [numpy, dpnp])
713+
def test_indices_ndim_axis_none(self, xp):
714+
a = xp.ones((10, 10))
715+
ind = xp.ones((10, 2), dtype=xp.intp)
716+
assert_raises(ValueError, xp.take_along_axis, a, ind, axis=None)
709717

710-
@pytest.mark.parametrize("arr_dt", get_all_dtypes())
718+
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
711719
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
712-
def test_empty(self, arr_dt, idx_dt):
713-
np_a = numpy.ones((3, 4, 5), dtype=arr_dt)
714-
np_ai = numpy.ones((3, 0, 5), dtype=idx_dt)
715-
716-
dp_a = dpnp.array(np_a, dtype=arr_dt)
717-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
720+
def test_empty(self, a_dt, idx_dt):
721+
a = numpy.ones((3, 4, 5), dtype=a_dt)
722+
ind = numpy.ones((3, 0, 5), dtype=idx_dt)
723+
ia, iind = dpnp.array(a), dpnp.array(ind)
718724

719-
expected = numpy.take_along_axis(np_a, np_ai, axis=1)
720-
result = dpnp.take_along_axis(dp_a, dp_ai, axis=1)
725+
result = dpnp.take_along_axis(ia, iind, axis=1)
726+
expected = numpy.take_along_axis(a, ind, axis=1)
721727
assert_array_equal(expected, result)
722728

723-
@pytest.mark.parametrize("arr_dt", get_all_dtypes())
729+
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
724730
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
725-
def test_broadcast(self, arr_dt, idx_dt):
726-
np_a = numpy.ones((3, 4, 1), dtype=arr_dt)
727-
np_ai = numpy.ones((1, 2, 5), dtype=idx_dt)
728-
729-
dp_a = dpnp.array(np_a, dtype=arr_dt)
730-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
731+
def test_broadcast(self, a_dt, idx_dt):
732+
a = numpy.ones((3, 4, 1), dtype=a_dt)
733+
ind = numpy.ones((1, 2, 5), dtype=idx_dt)
734+
ia, iind = dpnp.array(a), dpnp.array(ind)
731735

732-
expected = numpy.take_along_axis(np_a, np_ai, axis=1)
733-
result = dpnp.take_along_axis(dp_a, dp_ai, axis=1)
736+
result = dpnp.take_along_axis(ia, iind, axis=1)
737+
expected = numpy.take_along_axis(a, ind, axis=1)
734738
assert_array_equal(expected, result)
735739

740+
def test_mode_wrap(self):
741+
a = numpy.array([-2, -1, 0, 1, 2])
742+
ind = numpy.array([-2, 2, -5, 4])
743+
ia, iind = dpnp.array(a), dpnp.array(ind)
744+
745+
result = dpnp.take_along_axis(ia, iind, axis=0, mode="wrap")
746+
expected = numpy.take_along_axis(a, ind, axis=0)
747+
assert_array_equal(result, expected)
748+
749+
def test_mode_clip(self):
750+
a = dpnp.array([-2, -1, 0, 1, 2])
751+
ind = dpnp.array([-2, 2, -5, 4])
752+
753+
# numpy does not support keyword `mode`
754+
result = dpnp.take_along_axis(a, ind, axis=0, mode="clip")
755+
assert (result == dpnp.array([-2, 0, -2, 2])).all()
756+
736757

737758
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
738759
def test_choose():

0 commit comments

Comments
 (0)