Skip to content

Commit d415fe2

Browse files
authored
Add descending keyword argument to dpnp.sort and dpnp.argsort (#2269)
The PR proposes to add `descending` keyword argument to sorting functions, including `dpnp.sort`, `dpnp.argsort` and `dpnp.ndarray.sort`, `dpnp.ndarray.argsort` methods. The keyword is mandated according to python array API. The corresponding muted tests are enabled in python array API compliance scope.
1 parent 356184a commit d415fe2

File tree

4 files changed

+199
-38
lines changed

4 files changed

+199
-38
lines changed

.github/workflows/array-api-skips.txt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_clip
4747
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
4848
array_api_tests/test_operators_and_elementwise_functions.py::test_asinh
4949

50-
# missing 'descending' keyword argument
51-
array_api_tests/test_signatures.py::test_func_signature[argsort]
52-
array_api_tests/test_signatures.py::test_func_signature[sort]
53-
array_api_tests/test_sorting_functions.py::test_argsort
54-
array_api_tests/test_sorting_functions.py::test_sort
55-
5650
# missing 'correction' keyword argument
5751
array_api_tests/test_signatures.py::test_func_signature[std]
5852
array_api_tests/test_signatures.py::test_func_signature[var]

dpnp/dpnp_array.py

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -700,14 +700,63 @@ def argmin(self, axis=None, out=None, *, keepdims=False):
700700

701701
# 'argpartition',
702702

703-
def argsort(self, axis=-1, kind=None, order=None):
703+
def argsort(
704+
self, axis=-1, kind=None, order=None, *, descending=False, stable=None
705+
):
704706
"""
705707
Return an ndarray of indices that sort the array along the specified axis.
706708
707709
Refer to :obj:`dpnp.argsort` for full documentation.
708710
711+
Parameters
712+
----------
713+
axis : {None, int}, optional
714+
Axis along which to sort. If ``None``, the array is flattened
715+
before sorting. The default is ``-1``, which sorts along the last
716+
axis.
717+
Default: ``-1``.
718+
kind : {None, "stable", "mergesort", "radixsort"}, optional
719+
Sorting algorithm. The default is ``None``, which uses parallel
720+
merge-sort or parallel radix-sort algorithms depending on the array
721+
data type.
722+
Default: ``None``.
723+
descending : bool, optional
724+
Sort order. If ``True``, the array must be sorted in descending
725+
order (by value). If ``False``, the array must be sorted in
726+
ascending order (by value).
727+
Default: ``False``.
728+
stable : {None, bool}, optional
729+
Sort stability. If ``True``, the returned array will maintain the
730+
relative order of `a` values which compare as equal. The same
731+
behavior applies when set to ``False`` or ``None``.
732+
Internally, this option selects ``kind="stable"``.
733+
Default: ``None``.
734+
735+
See Also
736+
--------
737+
:obj:`dpnp.sort` : Return a sorted copy of an array.
738+
:obj:`dpnp.argsort` : Return the indices that would sort an array.
739+
:obj:`dpnp.lexsort` : Indirect stable sort on multiple keys.
740+
:obj:`dpnp.searchsorted` : Find elements in a sorted array.
741+
:obj:`dpnp.partition` : Partial sort.
742+
743+
Examples
744+
--------
745+
>>> import dpnp as np
746+
>>> a = np.array([3, 1, 2])
747+
>>> a.argsort()
748+
array([1, 2, 0])
749+
750+
>>> a = np.array([[0, 3], [2, 2]])
751+
>>> a.argsort(axis=0)
752+
array([[0, 1],
753+
[1, 0]])
754+
709755
"""
710-
return dpnp.argsort(self, axis, kind, order)
756+
757+
return dpnp.argsort(
758+
self, axis, kind, order, descending=descending, stable=stable
759+
)
711760

712761
def asnumpy(self):
713762
"""
@@ -1589,12 +1638,45 @@ def size(self):
15891638

15901639
return self._array_obj.size
15911640

1592-
def sort(self, axis=-1, kind=None, order=None):
1641+
def sort(
1642+
self, axis=-1, kind=None, order=None, *, descending=False, stable=None
1643+
):
15931644
"""
15941645
Sort an array in-place.
15951646
15961647
Refer to :obj:`dpnp.sort` for full documentation.
15971648
1649+
Parameters
1650+
----------
1651+
axis : int, optional
1652+
Axis along which to sort. The default is ``-1``, which sorts along
1653+
the last axis.
1654+
Default: ``-1``.
1655+
kind : {None, "stable", "mergesort", "radixsort"}, optional
1656+
Sorting algorithm. The default is ``None``, which uses parallel
1657+
merge-sort or parallel radix-sort algorithms depending on the array
1658+
data type.
1659+
Default: ``None``.
1660+
descending : bool, optional
1661+
Sort order. If ``True``, the array must be sorted in descending
1662+
order (by value). If ``False``, the array must be sorted in
1663+
ascending order (by value).
1664+
Default: ``False``.
1665+
stable : {None, bool}, optional
1666+
Sort stability. If ``True``, the returned array will maintain the
1667+
relative order of `a` values which compare as equal. The same
1668+
behavior applies when set to ``False`` or ``None``.
1669+
Internally, this option selects ``kind="stable"``.
1670+
Default: ``None``.
1671+
1672+
See Also
1673+
--------
1674+
:obj:`dpnp.sort` : Return a sorted copy of an array.
1675+
:obj:`dpnp.argsort` : Return the indices that would sort an array.
1676+
:obj:`dpnp.lexsort` : Indirect stable sort on multiple keys.
1677+
:obj:`dpnp.searchsorted` : Find elements in a sorted array.
1678+
:obj:`dpnp.partition` : Partial sort.
1679+
15981680
Note
15991681
----
16001682
`axis` in :obj:`dpnp.sort` could be integer or ``None``. If ``None``,
@@ -1605,7 +1687,7 @@ def sort(self, axis=-1, kind=None, order=None):
16051687
Examples
16061688
--------
16071689
>>> import dpnp as np
1608-
>>> a = np.array([[1,4],[3,1]])
1690+
>>> a = np.array([[1, 4], [3, 1]])
16091691
>>> a.sort(axis=1)
16101692
>>> a
16111693
array([[1, 4],
@@ -1621,7 +1703,14 @@ def sort(self, axis=-1, kind=None, order=None):
16211703
raise TypeError(
16221704
"'NoneType' object cannot be interpreted as an integer"
16231705
)
1624-
self[...] = dpnp.sort(self, axis=axis, kind=kind, order=order)
1706+
self[...] = dpnp.sort(
1707+
self,
1708+
axis=axis,
1709+
kind=kind,
1710+
order=order,
1711+
descending=descending,
1712+
stable=stable,
1713+
)
16251714

16261715
def squeeze(self, axis=None):
16271716
"""

dpnp/dpnp_iface_sorting.py

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,13 @@
5858

5959

6060
def _wrap_sort_argsort(
61-
a, _sorting_fn, axis=-1, kind=None, order=None, stable=True
61+
a,
62+
_sorting_fn,
63+
axis=-1,
64+
kind=None,
65+
order=None,
66+
descending=False,
67+
stable=True,
6268
):
6369
"""Wrap a sorting call from dpctl.tensor interface."""
6470

@@ -83,11 +89,15 @@ def _wrap_sort_argsort(
8389
axis = -1
8490

8591
axis = normalize_axis_index(axis, ndim=usm_a.ndim)
86-
usm_res = _sorting_fn(usm_a, axis=axis, stable=stable, kind=kind)
92+
usm_res = _sorting_fn(
93+
usm_a, axis=axis, descending=descending, stable=stable, kind=kind
94+
)
8795
return dpnp_array._create_from_usm_ndarray(usm_res)
8896

8997

90-
def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
98+
def argsort(
99+
a, axis=-1, kind=None, order=None, *, descending=False, stable=None
100+
):
91101
"""
92102
Returns the indices that would sort an array.
93103
@@ -100,13 +110,21 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
100110
axis : {None, int}, optional
101111
Axis along which to sort. If ``None``, the array is flattened before
102112
sorting. The default is ``-1``, which sorts along the last axis.
113+
Default: ``-1``.
103114
kind : {None, "stable", "mergesort", "radixsort"}, optional
104-
Sorting algorithm. Default is ``None``, which is equivalent to
105-
``"stable"``.
115+
Sorting algorithm. The default is ``None``, which uses parallel
116+
merge-sort or parallel radix-sort algorithms depending on the array
117+
data type.
118+
Default: ``None``.
119+
descending : bool, optional
120+
Sort order. If ``True``, the array must be sorted in descending order
121+
(by value). If ``False``, the array must be sorted in ascending order
122+
(by value).
123+
Default: ``False``.
106124
stable : {None, bool}, optional
107-
Sort stability. If ``True``, the returned array will maintain
108-
the relative order of ``a`` values which compare as equal.
109-
The same behavior applies when set to ``False`` or ``None``.
125+
Sort stability. If ``True``, the returned array will maintain the
126+
relative order of `a` values which compare as equal. The same behavior
127+
applies when set to ``False`` or ``None``.
110128
Internally, this option selects ``kind="stable"``.
111129
Default: ``None``.
112130
@@ -130,7 +148,6 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
130148
Otherwise ``NotImplementedError`` exception will be raised.
131149
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.
132150
133-
134151
See Also
135152
--------
136153
:obj:`dpnp.ndarray.argsort` : Equivalent method.
@@ -171,7 +188,13 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
171188
"""
172189

173190
return _wrap_sort_argsort(
174-
a, dpt.argsort, axis=axis, kind=kind, order=order, stable=stable
191+
a,
192+
dpt.argsort,
193+
axis=axis,
194+
kind=kind,
195+
order=order,
196+
descending=descending,
197+
stable=stable,
175198
)
176199

177200

@@ -215,7 +238,7 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
215238
return call_origin(numpy.partition, x1, kth, axis, kind, order)
216239

217240

218-
def sort(a, axis=-1, kind=None, order=None, *, stable=None):
241+
def sort(a, axis=-1, kind=None, order=None, *, descending=False, stable=None):
219242
"""
220243
Return a sorted copy of an array.
221244
@@ -228,13 +251,21 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=None):
228251
axis : {None, int}, optional
229252
Axis along which to sort. If ``None``, the array is flattened before
230253
sorting. The default is ``-1``, which sorts along the last axis.
254+
Default: ``-1``.
231255
kind : {None, "stable", "mergesort", "radixsort"}, optional
232-
Sorting algorithm. Default is ``None``, which is equivalent to
233-
``"stable"``.
256+
Sorting algorithm. The default is ``None``, which uses parallel
257+
merge-sort or parallel radix-sort algorithms depending on the array
258+
data type.
259+
Default: ``None``.
260+
descending : bool, optional
261+
Sort order. If ``True``, the array must be sorted in descending order
262+
(by value). If ``False``, the array must be sorted in ascending order
263+
(by value).
264+
Default: ``False``.
234265
stable : {None, bool}, optional
235-
Sort stability. If ``True``, the returned array will maintain
236-
the relative order of ``a`` values which compare as equal.
237-
The same behavior applies when set to ``False`` or ``None``.
266+
Sort stability. If ``True``, the returned array will maintain the
267+
relative order of `a` values which compare as equal. The same behavior
268+
applies when set to ``False`` or ``None``.
238269
Internally, this option selects ``kind="stable"``.
239270
Default: ``None``.
240271
@@ -265,7 +296,7 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=None):
265296
Examples
266297
--------
267298
>>> import dpnp as np
268-
>>> a = np.array([[1,4],[3,1]])
299+
>>> a = np.array([[1, 4], [3, 1]])
269300
>>> np.sort(a) # sort along the last axis
270301
array([[1, 4],
271302
[1, 3]])
@@ -278,7 +309,13 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=None):
278309
"""
279310

280311
return _wrap_sort_argsort(
281-
a, dpt.sort, axis=axis, kind=kind, order=order, stable=stable
312+
a,
313+
dpt.sort,
314+
axis=axis,
315+
kind=kind,
316+
order=order,
317+
descending=descending,
318+
stable=stable,
282319
)
283320

284321

0 commit comments

Comments
 (0)