Skip to content

Commit b43e7e3

Browse files
committed
update tests
1 parent 413a4c8 commit b43e7e3

File tree

1 file changed

+36
-24
lines changed

1 file changed

+36
-24
lines changed

tests/test_sort.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,34 @@
1515

1616

1717
class TestArgsort:
18+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
1819
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
19-
def test_argsort_dtype(self, dtype):
20+
def test_basic(self, kind, dtype):
2021
a = numpy.random.uniform(-5, 5, 10)
2122
np_array = numpy.array(a, dtype=dtype)
2223
dp_array = dpnp.array(np_array)
2324

24-
result = dpnp.argsort(dp_array, kind="stable")
25+
result = dpnp.argsort(dp_array, kind=kind)
2526
expected = numpy.argsort(np_array, kind="stable")
2627
assert_dtype_allclose(result, expected)
2728

29+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
2830
@pytest.mark.parametrize("dtype", get_complex_dtypes())
29-
def test_argsort_complex(self, dtype):
31+
def test_complex(self, kind, dtype):
3032
a = numpy.random.uniform(-5, 5, 10)
3133
b = numpy.random.uniform(-5, 5, 10)
3234
np_array = numpy.array(a + b * 1j, dtype=dtype)
3335
dp_array = dpnp.array(np_array)
3436

35-
result = dpnp.argsort(dp_array)
36-
expected = numpy.argsort(np_array)
37-
assert_dtype_allclose(result, expected)
37+
if kind == "radixsort":
38+
assert_raises(ValueError, dpnp.argsort, dp_array, kind=kind)
39+
else:
40+
result = dpnp.argsort(dp_array, kind=kind)
41+
expected = numpy.argsort(np_array)
42+
assert_dtype_allclose(result, expected)
3843

3944
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2])
40-
def test_argsort_axis(self, axis):
45+
def test_axis(self, axis):
4146
a = numpy.random.uniform(-10, 10, 36)
4247
np_array = numpy.array(a).reshape(3, 4, 3)
4348
dp_array = dpnp.array(np_array)
@@ -48,7 +53,7 @@ def test_argsort_axis(self, axis):
4853

4954
@pytest.mark.parametrize("dtype", get_all_dtypes())
5055
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1])
51-
def test_argsort_ndarray(self, dtype, axis):
56+
def test_ndarray(self, dtype, axis):
5257
if dtype and issubclass(dtype, numpy.integer):
5358
a = numpy.random.choice(
5459
numpy.arange(-10, 10), replace=False, size=12
@@ -62,8 +67,9 @@ def test_argsort_ndarray(self, dtype, axis):
6267
expected = np_array.argsort(axis=axis)
6368
assert_dtype_allclose(result, expected)
6469

70+
# this test validates that all diffeernt options of kind in dpnp are stable
6571
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
66-
def test_sort_kind(self, kind):
72+
def test_kind(self, kind):
6773
np_array = numpy.repeat(numpy.arange(10), 10)
6874
dp_array = dpnp.array(np_array)
6975

@@ -74,15 +80,15 @@ def test_sort_kind(self, kind):
7480
# `stable` keyword is supported in numpy 2.0 and above
7581
@testing.with_requires("numpy>=2.0")
7682
@pytest.mark.parametrize("stable", [None, False, True])
77-
def test_sort_stable(self, stable):
83+
def test_stable(self, stable):
7884
np_array = numpy.repeat(numpy.arange(10), 10)
7985
dp_array = dpnp.array(np_array)
8086

8187
result = dpnp.argsort(dp_array, stable="stable")
8288
expected = numpy.argsort(np_array, stable=True)
8389
assert_dtype_allclose(result, expected)
8490

85-
def test_argsort_zero_dim(self):
91+
def test_zero_dim(self):
8692
np_array = numpy.array(2.5)
8793
dp_array = dpnp.array(np_array)
8894

@@ -266,29 +272,34 @@ def test_v_scalar(self):
266272

267273

268274
class TestSort:
275+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
269276
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
270-
def test_sort_dtype(self, dtype):
277+
def test_basic(self, kind, dtype):
271278
a = numpy.random.uniform(-5, 5, 10)
272279
np_array = numpy.array(a, dtype=dtype)
273280
dp_array = dpnp.array(np_array)
274281

275-
result = dpnp.sort(dp_array)
282+
result = dpnp.sort(dp_array, kind=kind)
276283
expected = numpy.sort(np_array)
277284
assert_dtype_allclose(result, expected)
278285

286+
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
279287
@pytest.mark.parametrize("dtype", get_complex_dtypes())
280-
def test_sort_complex(self, dtype):
288+
def test_complex(self, kind, dtype):
281289
a = numpy.random.uniform(-5, 5, 10)
282290
b = numpy.random.uniform(-5, 5, 10)
283291
np_array = numpy.array(a + b * 1j, dtype=dtype)
284292
dp_array = dpnp.array(np_array)
285293

286-
result = dpnp.sort(dp_array)
287-
expected = numpy.sort(np_array)
288-
assert_dtype_allclose(result, expected)
294+
if kind == "radixsort":
295+
assert_raises(ValueError, dpnp.argsort, dp_array, kind=kind)
296+
else:
297+
result = dpnp.sort(dp_array, kind=kind)
298+
expected = numpy.sort(np_array)
299+
assert_dtype_allclose(result, expected)
289300

290301
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2])
291-
def test_sort_axis(self, axis):
302+
def test_axis(self, axis):
292303
a = numpy.random.uniform(-10, 10, 36)
293304
np_array = numpy.array(a).reshape(3, 4, 3)
294305
dp_array = dpnp.array(np_array)
@@ -299,7 +310,7 @@ def test_sort_axis(self, axis):
299310

300311
@pytest.mark.parametrize("dtype", get_all_dtypes())
301312
@pytest.mark.parametrize("axis", [-2, -1, 0, 1])
302-
def test_sort_ndarray(self, dtype, axis):
313+
def test_ndarray(self, dtype, axis):
303314
a = numpy.random.uniform(-10, 10, 12)
304315
np_array = numpy.array(a, dtype=dtype).reshape(6, 2)
305316
dp_array = dpnp.array(np_array)
@@ -308,8 +319,9 @@ def test_sort_ndarray(self, dtype, axis):
308319
np_array.sort(axis=axis)
309320
assert_dtype_allclose(dp_array, np_array)
310321

322+
# this test validates that all diffeernt options of kind in dpnp are stable
311323
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
312-
def test_sort_kind(self, kind):
324+
def test_kind(self, kind):
313325
np_array = numpy.repeat(numpy.arange(10), 10)
314326
dp_array = dpnp.array(np_array)
315327

@@ -320,21 +332,21 @@ def test_sort_kind(self, kind):
320332
# `stable` keyword is supported in numpy 2.0 and above
321333
@testing.with_requires("numpy>=2.0")
322334
@pytest.mark.parametrize("stable", [None, False, True])
323-
def test_sort_stable(self, stable):
335+
def test_stable(self, stable):
324336
np_array = numpy.repeat(numpy.arange(10), 10)
325337
dp_array = dpnp.array(np_array)
326338

327339
result = dpnp.sort(dp_array, stable="stable")
328340
expected = numpy.sort(np_array, stable=True)
329341
assert_dtype_allclose(result, expected)
330342

331-
def test_sort_ndarray_axis_none(self):
343+
def test_ndarray_axis_none(self):
332344
a = numpy.random.uniform(-10, 10, 12)
333345
dp_array = dpnp.array(a).reshape(6, 2)
334346
with pytest.raises(TypeError):
335347
dp_array.sort(axis=None)
336348

337-
def test_sort_zero_dim(self):
349+
def test_zero_dim(self):
338350
np_array = numpy.array(2.5)
339351
dp_array = dpnp.array(np_array)
340352

@@ -347,7 +359,7 @@ def test_sort_zero_dim(self):
347359
expected = numpy.sort(np_array, axis=None)
348360
assert_dtype_allclose(result, expected)
349361

350-
def test_sort_error(self):
362+
def test_error(self):
351363
dp_array = dpnp.arange(10)
352364

353365
# quicksort is currently not supported

0 commit comments

Comments
 (0)