Skip to content

Commit dcd3d98

Browse files
committed
Add new nanargmin/nanargmax tests
1 parent 0bfe1e3 commit dcd3d98

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

dpnp/tests/third_party/cupy/sorting_tests/test_search.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,22 @@ def test_nanargmin_zero_size_axis1(self, xp, dtype):
532532
a = testing.shaped_random((0, 1), xp, dtype)
533533
return xp.nanargmin(a, axis=1)
534534

535+
@testing.for_all_dtypes(no_complex=True)
536+
@testing.numpy_cupy_allclose()
537+
def test_nanargmin_out_float_dtype(self, xp, dtype):
538+
a = xp.array([[0.0]])
539+
b = xp.empty((1), dtype="int64")
540+
xp.nanargmin(a, axis=1, out=b)
541+
return b
542+
543+
@testing.for_all_dtypes(no_complex=True)
544+
@testing.numpy_cupy_array_equal()
545+
def test_nanargmin_out_int_dtype(self, xp, dtype):
546+
a = xp.array([1, 0])
547+
b = xp.empty((), dtype="int64")
548+
xp.nanargmin(a, out=b)
549+
return b
550+
535551

536552
class TestNanArgMax:
537553

@@ -623,6 +639,22 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype):
623639
a = testing.shaped_random((0, 1), xp, dtype)
624640
return xp.nanargmax(a, axis=1)
625641

642+
@testing.for_all_dtypes(no_complex=True)
643+
@testing.numpy_cupy_allclose()
644+
def test_nanargmax_out_float_dtype(self, xp, dtype):
645+
a = xp.array([[0.0]])
646+
b = xp.empty((1), dtype="int64")
647+
xp.nanargmax(a, axis=1, out=b)
648+
return b
649+
650+
@testing.for_all_dtypes(no_complex=True)
651+
@testing.numpy_cupy_array_equal()
652+
def test_nanargmax_out_int_dtype(self, xp, dtype):
653+
a = xp.array([0, 1])
654+
b = xp.empty((), dtype="int64")
655+
xp.nanargmax(a, out=b)
656+
return b
657+
626658

627659
@testing.parameterize(
628660
*testing.product(
@@ -771,7 +803,7 @@ def test_invalid_sorter(self):
771803

772804
def test_nonint_sorter(self):
773805
for xp in (numpy, cupy):
774-
x = testing.shaped_arange((12,), xp, xp.float32)
806+
x = testing.shaped_arange((12,), xp, xp.float64)
775807
bins = xp.array([10, 4, 2, 1, 8])
776808
sorter = xp.array([], dtype=xp.float32)
777809
with pytest.raises((TypeError, ValueError)):
@@ -865,7 +897,7 @@ def test_invalid_sorter(self):
865897

866898
def test_nonint_sorter(self):
867899
for xp in (numpy, cupy):
868-
x = testing.shaped_arange((12,), xp, xp.float32)
900+
x = testing.shaped_arange((12,), xp, xp.float64)
869901
bins = xp.array([10, 4, 2, 1, 8])
870902
sorter = xp.array([], dtype=xp.float32)
871903
with pytest.raises((TypeError, ValueError)):

0 commit comments

Comments
 (0)