@@ -544,6 +544,13 @@ def test_values(self, arr_dt, idx_dt, ndim, values):
544
544
dpnp .put_along_axis (dp_a , dp_ai , values , axis )
545
545
assert_array_equal (np_a , dp_a )
546
546
547
+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
548
+ @pytest .mark .parametrize ("dt" , [bool , numpy .float32 ])
549
+ def test_invalid_indices_dtype (self , xp , dt ):
550
+ a = xp .ones ((10 , 10 ))
551
+ ind = xp .ones (10 , dtype = dt )
552
+ assert_raises (IndexError , xp .put_along_axis , a , ind , 7 , axis = 1 )
553
+
547
554
@pytest .mark .parametrize ("arr_dt" , get_all_dtypes ())
548
555
@pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
549
556
def test_broadcast (self , arr_dt , idx_dt ):
@@ -673,66 +680,80 @@ def test_argequivalent(self, func, argfunc, kwargs):
673
680
@pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
674
681
@pytest .mark .parametrize ("ndim" , list (range (1 , 4 )))
675
682
def test_multi_dimensions (self , arr_dt , idx_dt , ndim ):
676
- np_a = numpy .arange (4 ** ndim , dtype = arr_dt ).reshape ((4 ,) * ndim )
677
- np_ai = numpy .array ([3 , 0 , 2 , 1 ], dtype = idx_dt ).reshape (
683
+ a = numpy .arange (4 ** ndim , dtype = arr_dt ).reshape ((4 ,) * ndim )
684
+ ind = numpy .array ([3 , 0 , 2 , 1 ], dtype = idx_dt ).reshape (
678
685
(1 ,) * (ndim - 1 ) + (4 ,)
679
686
)
680
-
681
- dp_a = dpnp .array (np_a , dtype = arr_dt )
682
- dp_ai = dpnp .array (np_ai , dtype = idx_dt )
687
+ ia , iind = dpnp .array (a ), dpnp .array (ind )
683
688
684
689
for axis in range (ndim ):
685
- expected = numpy .take_along_axis (np_a , np_ai , axis )
686
- result = dpnp .take_along_axis (dp_a , dp_ai , axis )
690
+ result = dpnp .take_along_axis (ia , iind , axis )
691
+ expected = numpy .take_along_axis (a , ind , axis )
687
692
assert_array_equal (expected , result )
688
693
689
694
@pytest .mark .parametrize ("xp" , [numpy , dpnp ])
690
- def test_invalid (self , xp ):
695
+ def test_not_enough_indices (self , xp ):
691
696
a = xp .ones ((10 , 10 ))
692
- ai = xp .ones ((10 , 2 ), dtype = xp .intp )
693
-
694
- # not enough indices
695
697
assert_raises (ValueError , xp .take_along_axis , a , xp .array (1 ), axis = 1 )
696
698
697
- # bool arrays not allowed
698
- assert_raises (
699
- IndexError , xp .take_along_axis , a , ai .astype (bool ), axis = 1
700
- )
699
+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
700
+ @pytest .mark .parametrize ("dt" , [bool , numpy .float32 ])
701
+ def test_invalid_indices_dtype (self , xp , dt ):
702
+ a = xp .ones ((10 , 10 ))
703
+ ind = xp .ones ((10 , 2 ), dtype = dt )
704
+ assert_raises (IndexError , xp .take_along_axis , a , ind , axis = 1 )
701
705
702
- # float arrays not allowed
703
- assert_raises (
704
- IndexError , xp .take_along_axis , a , ai .astype (numpy .float32 ), axis = 1
705
- )
706
+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
707
+ def test_invalid_axis (self , xp ):
708
+ a = xp .ones ((10 , 10 ))
709
+ ind = xp .ones ((10 , 2 ), dtype = xp .intp )
710
+ assert_raises (AxisError , xp .take_along_axis , a , ind , axis = 10 )
706
711
707
- # invalid axis
708
- assert_raises (AxisError , xp .take_along_axis , a , ai , axis = 10 )
712
+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
713
+ def test_indices_ndim_axis_none (self , xp ):
714
+ a = xp .ones ((10 , 10 ))
715
+ ind = xp .ones ((10 , 2 ), dtype = xp .intp )
716
+ assert_raises (ValueError , xp .take_along_axis , a , ind , axis = None )
709
717
710
- @pytest .mark .parametrize ("arr_dt " , get_all_dtypes ())
718
+ @pytest .mark .parametrize ("a_dt " , get_all_dtypes (no_none = True ))
711
719
@pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
712
- def test_empty (self , arr_dt , idx_dt ):
713
- np_a = numpy .ones ((3 , 4 , 5 ), dtype = arr_dt )
714
- np_ai = numpy .ones ((3 , 0 , 5 ), dtype = idx_dt )
715
-
716
- dp_a = dpnp .array (np_a , dtype = arr_dt )
717
- dp_ai = dpnp .array (np_ai , dtype = idx_dt )
720
+ def test_empty (self , a_dt , idx_dt ):
721
+ a = numpy .ones ((3 , 4 , 5 ), dtype = a_dt )
722
+ ind = numpy .ones ((3 , 0 , 5 ), dtype = idx_dt )
723
+ ia , iind = dpnp .array (a ), dpnp .array (ind )
718
724
719
- expected = numpy .take_along_axis (np_a , np_ai , axis = 1 )
720
- result = dpnp .take_along_axis (dp_a , dp_ai , axis = 1 )
725
+ result = dpnp .take_along_axis (ia , iind , axis = 1 )
726
+ expected = numpy .take_along_axis (a , ind , axis = 1 )
721
727
assert_array_equal (expected , result )
722
728
723
- @pytest .mark .parametrize ("arr_dt " , get_all_dtypes ())
729
+ @pytest .mark .parametrize ("a_dt " , get_all_dtypes (no_none = True ))
724
730
@pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
725
- def test_broadcast (self , arr_dt , idx_dt ):
726
- np_a = numpy .ones ((3 , 4 , 1 ), dtype = arr_dt )
727
- np_ai = numpy .ones ((1 , 2 , 5 ), dtype = idx_dt )
728
-
729
- dp_a = dpnp .array (np_a , dtype = arr_dt )
730
- dp_ai = dpnp .array (np_ai , dtype = idx_dt )
731
+ def test_broadcast (self , a_dt , idx_dt ):
732
+ a = numpy .ones ((3 , 4 , 1 ), dtype = a_dt )
733
+ ind = numpy .ones ((1 , 2 , 5 ), dtype = idx_dt )
734
+ ia , iind = dpnp .array (a ), dpnp .array (ind )
731
735
732
- expected = numpy .take_along_axis (np_a , np_ai , axis = 1 )
733
- result = dpnp .take_along_axis (dp_a , dp_ai , axis = 1 )
736
+ result = dpnp .take_along_axis (ia , iind , axis = 1 )
737
+ expected = numpy .take_along_axis (a , ind , axis = 1 )
734
738
assert_array_equal (expected , result )
735
739
740
+ def test_mode_wrap (self ):
741
+ a = numpy .array ([- 2 , - 1 , 0 , 1 , 2 ])
742
+ ind = numpy .array ([- 2 , 2 , - 5 , 4 ])
743
+ ia , iind = dpnp .array (a ), dpnp .array (ind )
744
+
745
+ result = dpnp .take_along_axis (ia , iind , axis = 0 , mode = "wrap" )
746
+ expected = numpy .take_along_axis (a , ind , axis = 0 )
747
+ assert_array_equal (result , expected )
748
+
749
+ def test_mode_clip (self ):
750
+ a = dpnp .array ([- 2 , - 1 , 0 , 1 , 2 ])
751
+ ind = dpnp .array ([- 2 , 2 , - 5 , 4 ])
752
+
753
+ # numpy does not support keyword `mode`
754
+ result = dpnp .take_along_axis (a , ind , axis = 0 , mode = "clip" )
755
+ assert (result == dpnp .array ([- 2 , 0 , - 2 , 2 ])).all ()
756
+
736
757
737
758
@pytest .mark .usefixtures ("allow_fall_back_on_numpy" )
738
759
def test_choose ():
0 commit comments