15
15
16
16
17
17
class TestArgsort :
18
+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
18
19
@pytest .mark .parametrize ("dtype" , get_all_dtypes (no_complex = True ))
19
- def test_argsort_dtype (self , dtype ):
20
+ def test_basic (self , kind , dtype ):
20
21
a = numpy .random .uniform (- 5 , 5 , 10 )
21
22
np_array = numpy .array (a , dtype = dtype )
22
23
dp_array = dpnp .array (np_array )
23
24
24
- result = dpnp .argsort (dp_array , kind = "stable" )
25
+ result = dpnp .argsort (dp_array , kind = kind )
25
26
expected = numpy .argsort (np_array , kind = "stable" )
26
27
assert_dtype_allclose (result , expected )
27
28
29
+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
28
30
@pytest .mark .parametrize ("dtype" , get_complex_dtypes ())
29
- def test_argsort_complex (self , dtype ):
31
+ def test_complex (self , kind , dtype ):
30
32
a = numpy .random .uniform (- 5 , 5 , 10 )
31
33
b = numpy .random .uniform (- 5 , 5 , 10 )
32
34
np_array = numpy .array (a + b * 1j , dtype = dtype )
33
35
dp_array = dpnp .array (np_array )
34
36
35
- result = dpnp .argsort (dp_array )
36
- expected = numpy .argsort (np_array )
37
- assert_dtype_allclose (result , expected )
37
+ if kind == "radixsort" :
38
+ assert_raises (ValueError , dpnp .argsort , dp_array , kind = kind )
39
+ else :
40
+ result = dpnp .argsort (dp_array , kind = kind )
41
+ expected = numpy .argsort (np_array )
42
+ assert_dtype_allclose (result , expected )
38
43
39
44
@pytest .mark .parametrize ("axis" , [None , - 2 , - 1 , 0 , 1 , 2 ])
40
- def test_argsort_axis (self , axis ):
45
+ def test_axis (self , axis ):
41
46
a = numpy .random .uniform (- 10 , 10 , 36 )
42
47
np_array = numpy .array (a ).reshape (3 , 4 , 3 )
43
48
dp_array = dpnp .array (np_array )
@@ -48,7 +53,7 @@ def test_argsort_axis(self, axis):
48
53
49
54
@pytest .mark .parametrize ("dtype" , get_all_dtypes ())
50
55
@pytest .mark .parametrize ("axis" , [None , - 2 , - 1 , 0 , 1 ])
51
- def test_argsort_ndarray (self , dtype , axis ):
56
+ def test_ndarray (self , dtype , axis ):
52
57
if dtype and issubclass (dtype , numpy .integer ):
53
58
a = numpy .random .choice (
54
59
numpy .arange (- 10 , 10 ), replace = False , size = 12
@@ -62,8 +67,9 @@ def test_argsort_ndarray(self, dtype, axis):
62
67
expected = np_array .argsort (axis = axis )
63
68
assert_dtype_allclose (result , expected )
64
69
65
- @pytest .mark .parametrize ("kind" , [None , "stable" ])
66
- def test_sort_kind (self , kind ):
70
+ # this test validates that all different options of kind in dpnp are stable
71
+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
72
+ def test_kind (self , kind ):
67
73
np_array = numpy .repeat (numpy .arange (10 ), 10 )
68
74
dp_array = dpnp .array (np_array )
69
75
@@ -74,15 +80,15 @@ def test_sort_kind(self, kind):
74
80
# `stable` keyword is supported in numpy 2.0 and above
75
81
@testing .with_requires ("numpy>=2.0" )
76
82
@pytest .mark .parametrize ("stable" , [None , False , True ])
77
- def test_sort_stable (self , stable ):
83
+ def test_stable (self , stable ):
78
84
np_array = numpy .repeat (numpy .arange (10 ), 10 )
79
85
dp_array = dpnp .array (np_array )
80
86
81
87
result = dpnp .argsort (dp_array , stable = "stable" )
82
88
expected = numpy .argsort (np_array , stable = True )
83
89
assert_dtype_allclose (result , expected )
84
90
85
- def test_argsort_zero_dim (self ):
91
+ def test_zero_dim (self ):
86
92
np_array = numpy .array (2.5 )
87
93
dp_array = dpnp .array (np_array )
88
94
@@ -266,29 +272,34 @@ def test_v_scalar(self):
266
272
267
273
268
274
class TestSort :
275
+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
269
276
@pytest .mark .parametrize ("dtype" , get_all_dtypes (no_complex = True ))
270
- def test_sort_dtype (self , dtype ):
277
+ def test_basic (self , kind , dtype ):
271
278
a = numpy .random .uniform (- 5 , 5 , 10 )
272
279
np_array = numpy .array (a , dtype = dtype )
273
280
dp_array = dpnp .array (np_array )
274
281
275
- result = dpnp .sort (dp_array )
282
+ result = dpnp .sort (dp_array , kind = kind )
276
283
expected = numpy .sort (np_array )
277
284
assert_dtype_allclose (result , expected )
278
285
286
+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
279
287
@pytest .mark .parametrize ("dtype" , get_complex_dtypes ())
280
- def test_sort_complex (self , dtype ):
288
+ def test_complex (self , kind , dtype ):
281
289
a = numpy .random .uniform (- 5 , 5 , 10 )
282
290
b = numpy .random .uniform (- 5 , 5 , 10 )
283
291
np_array = numpy .array (a + b * 1j , dtype = dtype )
284
292
dp_array = dpnp .array (np_array )
285
293
286
- result = dpnp .sort (dp_array )
287
- expected = numpy .sort (np_array )
288
- assert_dtype_allclose (result , expected )
294
+ if kind == "radixsort" :
295
+ assert_raises (ValueError , dpnp .argsort , dp_array , kind = kind )
296
+ else :
297
+ result = dpnp .sort (dp_array , kind = kind )
298
+ expected = numpy .sort (np_array )
299
+ assert_dtype_allclose (result , expected )
289
300
290
301
@pytest .mark .parametrize ("axis" , [None , - 2 , - 1 , 0 , 1 , 2 ])
291
- def test_sort_axis (self , axis ):
302
+ def test_axis (self , axis ):
292
303
a = numpy .random .uniform (- 10 , 10 , 36 )
293
304
np_array = numpy .array (a ).reshape (3 , 4 , 3 )
294
305
dp_array = dpnp .array (np_array )
@@ -299,7 +310,7 @@ def test_sort_axis(self, axis):
299
310
300
311
@pytest .mark .parametrize ("dtype" , get_all_dtypes ())
301
312
@pytest .mark .parametrize ("axis" , [- 2 , - 1 , 0 , 1 ])
302
- def test_sort_ndarray (self , dtype , axis ):
313
+ def test_ndarray (self , dtype , axis ):
303
314
a = numpy .random .uniform (- 10 , 10 , 12 )
304
315
np_array = numpy .array (a , dtype = dtype ).reshape (6 , 2 )
305
316
dp_array = dpnp .array (np_array )
@@ -308,8 +319,9 @@ def test_sort_ndarray(self, dtype, axis):
308
319
np_array .sort (axis = axis )
309
320
assert_dtype_allclose (dp_array , np_array )
310
321
311
- @pytest .mark .parametrize ("kind" , [None , "stable" ])
312
- def test_sort_kind (self , kind ):
322
+ # this test validates that all different options of kind in dpnp are stable
323
+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
324
+ def test_kind (self , kind ):
313
325
np_array = numpy .repeat (numpy .arange (10 ), 10 )
314
326
dp_array = dpnp .array (np_array )
315
327
@@ -320,21 +332,21 @@ def test_sort_kind(self, kind):
320
332
# `stable` keyword is supported in numpy 2.0 and above
321
333
@testing .with_requires ("numpy>=2.0" )
322
334
@pytest .mark .parametrize ("stable" , [None , False , True ])
323
- def test_sort_stable (self , stable ):
335
+ def test_stable (self , stable ):
324
336
np_array = numpy .repeat (numpy .arange (10 ), 10 )
325
337
dp_array = dpnp .array (np_array )
326
338
327
339
result = dpnp .sort (dp_array , stable = "stable" )
328
340
expected = numpy .sort (np_array , stable = True )
329
341
assert_dtype_allclose (result , expected )
330
342
331
- def test_sort_ndarray_axis_none (self ):
343
+ def test_ndarray_axis_none (self ):
332
344
a = numpy .random .uniform (- 10 , 10 , 12 )
333
345
dp_array = dpnp .array (a ).reshape (6 , 2 )
334
346
with pytest .raises (TypeError ):
335
347
dp_array .sort (axis = None )
336
348
337
- def test_sort_zero_dim (self ):
349
+ def test_zero_dim (self ):
338
350
np_array = numpy .array (2.5 )
339
351
dp_array = dpnp .array (np_array )
340
352
@@ -347,15 +359,20 @@ def test_sort_zero_dim(self):
347
359
expected = numpy .sort (np_array , axis = None )
348
360
assert_dtype_allclose (result , expected )
349
361
350
- def test_sort_notimplemented (self ):
362
+ def test_error (self ):
351
363
dp_array = dpnp .arange (10 )
352
364
353
- with pytest .raises (NotImplementedError ):
365
+ # quicksort is currently not supported
366
+ with pytest .raises (ValueError ):
354
367
dpnp .sort (dp_array , kind = "quicksort" )
355
368
356
369
with pytest .raises (NotImplementedError ):
357
370
dpnp .sort (dp_array , order = ["age" ])
358
371
372
+ # both kind and stable are given
373
+ with pytest .raises (ValueError ):
374
+ dpnp .sort (dp_array , kind = "mergesort" , stable = True )
375
+
359
376
360
377
class TestSortComplex :
361
378
@pytest .mark .parametrize (
0 commit comments