Skip to content

Commit d292956

Browse files
committed
Update sorting_tests/test_sort.py
1 parent 4e7b432 commit d292956

File tree

1 file changed

+17
-32
lines changed

1 file changed

+17
-32
lines changed

tests/third_party/cupy/sorting_tests/test_sort.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def get_array_module(*args):
2020

2121

2222
class TestSort(unittest.TestCase):
23+
2324
# Test ranks
2425

2526
def test_sort_zero_dim(self):
@@ -68,11 +69,11 @@ def test_sort_contiguous(self, xp):
6869
a.sort()
6970
return a
7071

71-
@testing.numpy_cupy_array_equal()
72-
def test_sort_non_contiguous(self, xp):
73-
a = testing.shaped_random((10,), xp)[::2] # Non contiguous view
74-
a.sort()
75-
return a
72+
@pytest.mark.skip("non-contiguous array is supported")
73+
def test_sort_non_contiguous(self):
74+
a = testing.shaped_random((10,), cupy)[::2] # Non contiguous view
75+
with self.assertRaises(NotImplementedError):
76+
a.sort()
7677

7778
@testing.numpy_cupy_array_equal()
7879
def test_external_sort_contiguous(self, xp):
@@ -214,6 +215,7 @@ def test_large(self, xp):
214215

215216
@pytest.mark.skip("lexsort() is not implemented yet")
216217
class TestLexsort(unittest.TestCase):
218+
217219
# Test ranks
218220

219221
# TODO(niboshi): Fix xfail
@@ -298,12 +300,15 @@ def test_F_order(self, xp):
298300
)
299301
)
300302
class TestArgsort(unittest.TestCase):
301-
def argsort(self, a, axis=-1, kind=None):
303+
304+
def argsort(self, a, axis=-1):
302305
if self.external:
306+
# Need to explicitly specify kind="stable"
307+
# numpy uses "quicksort" as default
303308
xp = cupy.get_array_module(a)
304-
return xp.argsort(a, axis=axis, kind=kind)
309+
return xp.argsort(a, axis=axis, kind="stable")
305310
else:
306-
return a.argsort(axis=axis, kind=kind)
311+
return a.argsort(axis=axis, kind="stable")
307312

308313
# Test base cases
309314

@@ -319,7 +324,7 @@ def test_argsort_zero_dim(self, xp, dtype):
319324
@testing.numpy_cupy_array_equal()
320325
def test_argsort_one_dim(self, xp, dtype):
321326
a = testing.shaped_random((10,), xp, dtype)
322-
return self.argsort(a, axis=-1, kind="stable")
327+
return self.argsort(a)
323328

324329
@testing.for_all_dtypes()
325330
@testing.numpy_cupy_array_equal()
@@ -414,30 +419,8 @@ def test_nan2(self, xp, dtype):
414419
return self.argsort(a)
415420

416421

417-
@pytest.mark.skip("msort() is deprecated")
418-
class TestMsort(unittest.TestCase):
419-
# Test base cases
420-
421-
def test_msort_zero_dim(self):
422-
for xp in (numpy, cupy):
423-
a = testing.shaped_random((), xp)
424-
with pytest.raises(AxisError):
425-
xp.msort(a)
426-
427-
@testing.for_all_dtypes()
428-
@testing.numpy_cupy_array_equal()
429-
def test_msort_one_dim(self, xp, dtype):
430-
a = testing.shaped_random((10,), xp, dtype)
431-
return xp.msort(a)
432-
433-
@testing.for_all_dtypes()
434-
@testing.numpy_cupy_array_equal()
435-
def test_msort_multi_dim(self, xp, dtype):
436-
a = testing.shaped_random((2, 3), xp, dtype)
437-
return xp.msort(a)
438-
439-
440422
class TestSort_complex(unittest.TestCase):
423+
441424
def test_sort_complex_zero_dim(self):
442425
for xp in (numpy, cupy):
443426
a = testing.shaped_random((), xp)
@@ -474,6 +457,7 @@ def test_sort_complex_nan(self, xp, dtype):
474457
)
475458
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
476459
class TestPartition(unittest.TestCase):
460+
477461
def partition(self, a, kth, axis=-1):
478462
if self.external:
479463
xp = cupy.get_array_module(a)
@@ -622,6 +606,7 @@ def test_partition_invalid_negative_axis2(self):
622606
)
623607
@pytest.mark.skip("not fully supported yet")
624608
class TestArgpartition(unittest.TestCase):
609+
625610
def argpartition(self, a, kth, axis=-1):
626611
if self.external:
627612
xp = cupy.get_array_module(a)

0 commit comments

Comments
 (0)