Skip to content

Commit bc31369

Browse files
authored
Merge c1cbdba into d06277f
2 parents d06277f + c1cbdba commit bc31369

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

dpnp/dpnp_iface_sorting.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ def _wrap_sort_argsort(
6666
raise NotImplementedError(
6767
"order keyword argument is only supported with its default value."
6868
)
69-
if kind is not None and kind != "stable":
70-
raise NotImplementedError(
71-
"kind keyword argument can only be None or 'stable'."
69+
if kind is not None and stable is not None:
70+
raise ValueError(
71+
"`kind` and `stable` parameters can't be provided at the same time."
72+
" Use only one of them."
7273
)
7374

7475
usm_a = dpnp.get_usm_ndarray(a)
@@ -77,11 +78,11 @@ def _wrap_sort_argsort(
7778
axis = -1
7879

7980
axis = normalize_axis_index(axis, ndim=usm_a.ndim)
80-
usm_res = _sorting_fn(usm_a, axis=axis, stable=stable)
81+
usm_res = _sorting_fn(usm_a, axis=axis, stable=stable, kind=kind)
8182
return dpnp_array._create_from_usm_ndarray(usm_res)
8283

8384

84-
def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
85+
def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
8586
"""
8687
Returns the indices that would sort an array.
8788
@@ -94,9 +95,9 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
9495
axis : {None, int}, optional
9596
Axis along which to sort. If ``None``, the array is flattened before
9697
sorting. The default is ``-1``, which sorts along the last axis.
97-
kind : {None, "stable"}, optional
98+
kind : {None, "stable", "mergesort", "radixsort"}, optional
9899
Sorting algorithm. Default is ``None``, which is equivalent to
99-
``"stable"``. Unlike NumPy, no other option is accepted here.
100+
``"stable"``.
100101
stable : {None, bool}, optional
101102
Sort stability. If ``True``, the returned array will maintain
102103
the relative order of ``a`` values which compare as equal.
@@ -121,8 +122,9 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
121122
Limitations
122123
-----------
123124
Parameters `order` is only supported with its default value.
124-
Parameter `kind` can only be ``None`` or ``"stable"`` which are equivalent.
125125
Otherwise ``NotImplementedError`` exception will be raised.
126+
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.
127+
126128
127129
See Also
128130
--------
@@ -203,7 +205,7 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
203205
return call_origin(numpy.partition, x1, kth, axis, kind, order)
204206

205207

206-
def sort(a, axis=-1, kind=None, order=None, *, stable=True):
208+
def sort(a, axis=-1, kind=None, order=None, *, stable=None):
207209
"""
208210
Return a sorted copy of an array.
209211
@@ -216,9 +218,9 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=True):
216218
axis : {None, int}, optional
217219
Axis along which to sort. If ``None``, the array is flattened before
218220
sorting. The default is ``-1``, which sorts along the last axis.
219-
kind : {None, "stable"}, optional
221+
kind : {None, "stable", "mergesort", "radixsort"}, optional
220222
Sorting algorithm. Default is ``None``, which is equivalent to
221-
``"stable"``. Unlike NumPy, no other option is accepted here.
223+
``"stable"``.
222224
stable : {None, bool}, optional
223225
Sort stability. If ``True``, the returned array will maintain
224226
the relative order of ``a`` values which compare as equal.
@@ -239,8 +241,8 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=True):
239241
Limitations
240242
-----------
241243
Parameters `order` is only supported with its default value.
242-
Parameter `kind` can only be ``None`` or ``"stable"`` which are equivalent.
243244
Otherwise ``NotImplementedError`` exception will be raised.
245+
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.
244246
245247
See Also
246248
--------

tests/test_sort.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_argsort_ndarray(self, dtype, axis):
6262
expected = np_array.argsort(axis=axis)
6363
assert_dtype_allclose(result, expected)
6464

65-
@pytest.mark.parametrize("kind", [None, "stable"])
65+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
6666
def test_sort_kind(self, kind):
6767
np_array = numpy.repeat(numpy.arange(10), 10)
6868
dp_array = dpnp.array(np_array)
@@ -308,7 +308,7 @@ def test_sort_ndarray(self, dtype, axis):
308308
np_array.sort(axis=axis)
309309
assert_dtype_allclose(dp_array, np_array)
310310

311-
@pytest.mark.parametrize("kind", [None, "stable"])
311+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
312312
def test_sort_kind(self, kind):
313313
np_array = numpy.repeat(numpy.arange(10), 10)
314314
dp_array = dpnp.array(np_array)
@@ -347,15 +347,20 @@ def test_sort_zero_dim(self):
347347
expected = numpy.sort(np_array, axis=None)
348348
assert_dtype_allclose(result, expected)
349349

350-
def test_sort_notimplemented(self):
350+
def test_sort_error(self):
351351
dp_array = dpnp.arange(10)
352352

353-
with pytest.raises(NotImplementedError):
353+
# quicksort is currently not supported
354+
with pytest.raises(ValueError):
354355
dpnp.sort(dp_array, kind="quicksort")
355356

356357
with pytest.raises(NotImplementedError):
357358
dpnp.sort(dp_array, order=["age"])
358359

360+
# both kind and stable are given
361+
with pytest.raises(ValueError):
362+
dpnp.sort(dp_array, kind="mergesort", stable=True)
363+
359364

360365
class TestSortComplex:
361366
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)