@@ -532,6 +532,22 @@ def test_nanargmin_zero_size_axis1(self, xp, dtype):
532
532
a = testing .shaped_random ((0 , 1 ), xp , dtype )
533
533
return xp .nanargmin (a , axis = 1 )
534
534
535
+ @testing .for_all_dtypes (no_complex = True )
536
+ @testing .numpy_cupy_allclose ()
537
+ def test_nanargmin_out_float_dtype (self , xp , dtype ):
538
+ a = xp .array ([[0.0 ]])
539
+ b = xp .empty ((1 ), dtype = "int64" )
540
+ xp .nanargmin (a , axis = 1 , out = b )
541
+ return b
542
+
543
+ @testing .for_all_dtypes (no_complex = True )
544
+ @testing .numpy_cupy_array_equal ()
545
+ def test_nanargmin_out_int_dtype (self , xp , dtype ):
546
+ a = xp .array ([1 , 0 ])
547
+ b = xp .empty ((), dtype = "int64" )
548
+ xp .nanargmin (a , out = b )
549
+ return b
550
+
535
551
536
552
class TestNanArgMax :
537
553
@@ -623,6 +639,22 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype):
623
639
a = testing .shaped_random ((0 , 1 ), xp , dtype )
624
640
return xp .nanargmax (a , axis = 1 )
625
641
642
+ @testing .for_all_dtypes (no_complex = True )
643
+ @testing .numpy_cupy_allclose ()
644
+ def test_nanargmax_out_float_dtype (self , xp , dtype ):
645
+ a = xp .array ([[0.0 ]])
646
+ b = xp .empty ((1 ), dtype = "int64" )
647
+ xp .nanargmax (a , axis = 1 , out = b )
648
+ return b
649
+
650
+ @testing .for_all_dtypes (no_complex = True )
651
+ @testing .numpy_cupy_array_equal ()
652
+ def test_nanargmax_out_int_dtype (self , xp , dtype ):
653
+ a = xp .array ([0 , 1 ])
654
+ b = xp .empty ((), dtype = "int64" )
655
+ xp .nanargmax (a , out = b )
656
+ return b
657
+
626
658
627
659
@testing .parameterize (
628
660
* testing .product (
@@ -771,7 +803,7 @@ def test_invalid_sorter(self):
771
803
772
804
def test_nonint_sorter (self ):
773
805
for xp in (numpy , cupy ):
774
- x = testing .shaped_arange ((12 ,), xp , xp .float32 )
806
+ x = testing .shaped_arange ((12 ,), xp , xp .float64 )
775
807
bins = xp .array ([10 , 4 , 2 , 1 , 8 ])
776
808
sorter = xp .array ([], dtype = xp .float32 )
777
809
with pytest .raises ((TypeError , ValueError )):
@@ -865,7 +897,7 @@ def test_invalid_sorter(self):
865
897
866
898
def test_nonint_sorter (self ):
867
899
for xp in (numpy , cupy ):
868
- x = testing .shaped_arange ((12 ,), xp , xp .float32 )
900
+ x = testing .shaped_arange ((12 ,), xp , xp .float64 )
869
901
bins = xp .array ([10 , 4 , 2 , 1 , 8 ])
870
902
sorter = xp .array ([], dtype = xp .float32 )
871
903
with pytest .raises ((TypeError , ValueError )):
0 commit comments