Skip to content

Commit 44c1ed3

Browse files
vtavanaantonwolfy
andauthored
add support for kind="mergesort" or "radixsort" for dpnp.sort and dpnp.argsort (#2159)
* add support for more option for kind keyword arguments * update tests * fix typo --------- Co-authored-by: Anton <[email protected]>
1 parent 11f251d commit 44c1ed3

File tree

2 files changed

+58
-39
lines changed

2 files changed

+58
-39
lines changed

dpnp/dpnp_iface_sorting.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ def _wrap_sort_argsort(
6666
raise NotImplementedError(
6767
"order keyword argument is only supported with its default value."
6868
)
69-
if kind is not None and kind != "stable":
70-
raise NotImplementedError(
71-
"kind keyword argument can only be None or 'stable'."
69+
if kind is not None and stable is not None:
70+
raise ValueError(
71+
"`kind` and `stable` parameters can't be provided at the same time."
72+
" Use only one of them."
7273
)
7374

7475
usm_a = dpnp.get_usm_ndarray(a)
@@ -77,11 +78,11 @@ def _wrap_sort_argsort(
7778
axis = -1
7879

7980
axis = normalize_axis_index(axis, ndim=usm_a.ndim)
80-
usm_res = _sorting_fn(usm_a, axis=axis, stable=stable)
81+
usm_res = _sorting_fn(usm_a, axis=axis, stable=stable, kind=kind)
8182
return dpnp_array._create_from_usm_ndarray(usm_res)
8283

8384

84-
def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
85+
def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
8586
"""
8687
Returns the indices that would sort an array.
8788
@@ -94,9 +95,9 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
9495
axis : {None, int}, optional
9596
Axis along which to sort. If ``None``, the array is flattened before
9697
sorting. The default is ``-1``, which sorts along the last axis.
97-
kind : {None, "stable"}, optional
98+
kind : {None, "stable", "mergesort", "radixsort"}, optional
9899
Sorting algorithm. Default is ``None``, which is equivalent to
99-
``"stable"``. Unlike NumPy, no other option is accepted here.
100+
``"stable"``.
100101
stable : {None, bool}, optional
101102
Sort stability. If ``True``, the returned array will maintain
102103
the relative order of ``a`` values which compare as equal.
@@ -121,8 +122,9 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
121122
Limitations
122123
-----------
123124
Parameters `order` is only supported with its default value.
124-
Parameter `kind` can only be ``None`` or ``"stable"`` which are equivalent.
125125
Otherwise ``NotImplementedError`` exception will be raised.
126+
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.
127+
126128
127129
See Also
128130
--------
@@ -203,7 +205,7 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
203205
return call_origin(numpy.partition, x1, kth, axis, kind, order)
204206

205207

206-
def sort(a, axis=-1, kind=None, order=None, *, stable=True):
208+
def sort(a, axis=-1, kind=None, order=None, *, stable=None):
207209
"""
208210
Return a sorted copy of an array.
209211
@@ -216,9 +218,9 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=True):
216218
axis : {None, int}, optional
217219
Axis along which to sort. If ``None``, the array is flattened before
218220
sorting. The default is ``-1``, which sorts along the last axis.
219-
kind : {None, "stable"}, optional
221+
kind : {None, "stable", "mergesort", "radixsort"}, optional
220222
Sorting algorithm. Default is ``None``, which is equivalent to
221-
``"stable"``. Unlike NumPy, no other option is accepted here.
223+
``"stable"``.
222224
stable : {None, bool}, optional
223225
Sort stability. If ``True``, the returned array will maintain
224226
the relative order of ``a`` values which compare as equal.
@@ -239,8 +241,8 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=True):
239241
Limitations
240242
-----------
241243
Parameters `order` is only supported with its default value.
242-
Parameter `kind` can only be ``None`` or ``"stable"`` which are equivalent.
243244
Otherwise ``NotImplementedError`` exception will be raised.
245+
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.
244246
245247
See Also
246248
--------

tests/test_sort.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,34 @@
1515

1616

1717
class TestArgsort:
18+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
1819
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
19-
def test_argsort_dtype(self, dtype):
20+
def test_basic(self, kind, dtype):
2021
a = numpy.random.uniform(-5, 5, 10)
2122
np_array = numpy.array(a, dtype=dtype)
2223
dp_array = dpnp.array(np_array)
2324

24-
result = dpnp.argsort(dp_array, kind="stable")
25+
result = dpnp.argsort(dp_array, kind=kind)
2526
expected = numpy.argsort(np_array, kind="stable")
2627
assert_dtype_allclose(result, expected)
2728

29+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
2830
@pytest.mark.parametrize("dtype", get_complex_dtypes())
29-
def test_argsort_complex(self, dtype):
31+
def test_complex(self, kind, dtype):
3032
a = numpy.random.uniform(-5, 5, 10)
3133
b = numpy.random.uniform(-5, 5, 10)
3234
np_array = numpy.array(a + b * 1j, dtype=dtype)
3335
dp_array = dpnp.array(np_array)
3436

35-
result = dpnp.argsort(dp_array)
36-
expected = numpy.argsort(np_array)
37-
assert_dtype_allclose(result, expected)
37+
if kind == "radixsort":
38+
assert_raises(ValueError, dpnp.argsort, dp_array, kind=kind)
39+
else:
40+
result = dpnp.argsort(dp_array, kind=kind)
41+
expected = numpy.argsort(np_array)
42+
assert_dtype_allclose(result, expected)
3843

3944
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2])
40-
def test_argsort_axis(self, axis):
45+
def test_axis(self, axis):
4146
a = numpy.random.uniform(-10, 10, 36)
4247
np_array = numpy.array(a).reshape(3, 4, 3)
4348
dp_array = dpnp.array(np_array)
@@ -48,7 +53,7 @@ def test_argsort_axis(self, axis):
4853

4954
@pytest.mark.parametrize("dtype", get_all_dtypes())
5055
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1])
51-
def test_argsort_ndarray(self, dtype, axis):
56+
def test_ndarray(self, dtype, axis):
5257
if dtype and issubclass(dtype, numpy.integer):
5358
a = numpy.random.choice(
5459
numpy.arange(-10, 10), replace=False, size=12
@@ -62,8 +67,9 @@ def test_argsort_ndarray(self, dtype, axis):
6267
expected = np_array.argsort(axis=axis)
6368
assert_dtype_allclose(result, expected)
6469

65-
@pytest.mark.parametrize("kind", [None, "stable"])
66-
def test_sort_kind(self, kind):
70+
# this test validates that all different options of kind in dpnp are stable
71+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
72+
def test_kind(self, kind):
6773
np_array = numpy.repeat(numpy.arange(10), 10)
6874
dp_array = dpnp.array(np_array)
6975

@@ -74,15 +80,15 @@ def test_sort_kind(self, kind):
7480
# `stable` keyword is supported in numpy 2.0 and above
7581
@testing.with_requires("numpy>=2.0")
7682
@pytest.mark.parametrize("stable", [None, False, True])
77-
def test_sort_stable(self, stable):
83+
def test_stable(self, stable):
7884
np_array = numpy.repeat(numpy.arange(10), 10)
7985
dp_array = dpnp.array(np_array)
8086

8187
result = dpnp.argsort(dp_array, stable="stable")
8288
expected = numpy.argsort(np_array, stable=True)
8389
assert_dtype_allclose(result, expected)
8490

85-
def test_argsort_zero_dim(self):
91+
def test_zero_dim(self):
8692
np_array = numpy.array(2.5)
8793
dp_array = dpnp.array(np_array)
8894

@@ -266,29 +272,34 @@ def test_v_scalar(self):
266272

267273

268274
class TestSort:
275+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
269276
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
270-
def test_sort_dtype(self, dtype):
277+
def test_basic(self, kind, dtype):
271278
a = numpy.random.uniform(-5, 5, 10)
272279
np_array = numpy.array(a, dtype=dtype)
273280
dp_array = dpnp.array(np_array)
274281

275-
result = dpnp.sort(dp_array)
282+
result = dpnp.sort(dp_array, kind=kind)
276283
expected = numpy.sort(np_array)
277284
assert_dtype_allclose(result, expected)
278285

286+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
279287
@pytest.mark.parametrize("dtype", get_complex_dtypes())
280-
def test_sort_complex(self, dtype):
288+
def test_complex(self, kind, dtype):
281289
a = numpy.random.uniform(-5, 5, 10)
282290
b = numpy.random.uniform(-5, 5, 10)
283291
np_array = numpy.array(a + b * 1j, dtype=dtype)
284292
dp_array = dpnp.array(np_array)
285293

286-
result = dpnp.sort(dp_array)
287-
expected = numpy.sort(np_array)
288-
assert_dtype_allclose(result, expected)
294+
if kind == "radixsort":
295+
assert_raises(ValueError, dpnp.argsort, dp_array, kind=kind)
296+
else:
297+
result = dpnp.sort(dp_array, kind=kind)
298+
expected = numpy.sort(np_array)
299+
assert_dtype_allclose(result, expected)
289300

290301
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2])
291-
def test_sort_axis(self, axis):
302+
def test_axis(self, axis):
292303
a = numpy.random.uniform(-10, 10, 36)
293304
np_array = numpy.array(a).reshape(3, 4, 3)
294305
dp_array = dpnp.array(np_array)
@@ -299,7 +310,7 @@ def test_sort_axis(self, axis):
299310

300311
@pytest.mark.parametrize("dtype", get_all_dtypes())
301312
@pytest.mark.parametrize("axis", [-2, -1, 0, 1])
302-
def test_sort_ndarray(self, dtype, axis):
313+
def test_ndarray(self, dtype, axis):
303314
a = numpy.random.uniform(-10, 10, 12)
304315
np_array = numpy.array(a, dtype=dtype).reshape(6, 2)
305316
dp_array = dpnp.array(np_array)
@@ -308,8 +319,9 @@ def test_sort_ndarray(self, dtype, axis):
308319
np_array.sort(axis=axis)
309320
assert_dtype_allclose(dp_array, np_array)
310321

311-
@pytest.mark.parametrize("kind", [None, "stable"])
312-
def test_sort_kind(self, kind):
322+
# this test validates that all different options of kind in dpnp are stable
323+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
324+
def test_kind(self, kind):
313325
np_array = numpy.repeat(numpy.arange(10), 10)
314326
dp_array = dpnp.array(np_array)
315327

@@ -320,21 +332,21 @@ def test_sort_kind(self, kind):
320332
# `stable` keyword is supported in numpy 2.0 and above
321333
@testing.with_requires("numpy>=2.0")
322334
@pytest.mark.parametrize("stable", [None, False, True])
323-
def test_sort_stable(self, stable):
335+
def test_stable(self, stable):
324336
np_array = numpy.repeat(numpy.arange(10), 10)
325337
dp_array = dpnp.array(np_array)
326338

327339
result = dpnp.sort(dp_array, stable="stable")
328340
expected = numpy.sort(np_array, stable=True)
329341
assert_dtype_allclose(result, expected)
330342

331-
def test_sort_ndarray_axis_none(self):
343+
def test_ndarray_axis_none(self):
332344
a = numpy.random.uniform(-10, 10, 12)
333345
dp_array = dpnp.array(a).reshape(6, 2)
334346
with pytest.raises(TypeError):
335347
dp_array.sort(axis=None)
336348

337-
def test_sort_zero_dim(self):
349+
def test_zero_dim(self):
338350
np_array = numpy.array(2.5)
339351
dp_array = dpnp.array(np_array)
340352

@@ -347,15 +359,20 @@ def test_sort_zero_dim(self):
347359
expected = numpy.sort(np_array, axis=None)
348360
assert_dtype_allclose(result, expected)
349361

350-
def test_sort_notimplemented(self):
362+
def test_error(self):
351363
dp_array = dpnp.arange(10)
352364

353-
with pytest.raises(NotImplementedError):
365+
# quicksort is currently not supported
366+
with pytest.raises(ValueError):
354367
dpnp.sort(dp_array, kind="quicksort")
355368

356369
with pytest.raises(NotImplementedError):
357370
dpnp.sort(dp_array, order=["age"])
358371

372+
# both kind and stable are given
373+
with pytest.raises(ValueError):
374+
dpnp.sort(dp_array, kind="mergesort", stable=True)
375+
359376

360377
class TestSortComplex:
361378
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)