@@ -109,6 +109,44 @@ def test_apply_along_axis_invalid_axis():
109
109
xp .apply_along_axis (xp .sum , axis , a )
110
110
111
111
112
+ class TestApplyOverAxes (unittest .TestCase ):
113
+
114
+ @testing .numpy_cupy_array_equal (type_check = has_support_aspect64 ())
115
+ def test_simple (self , xp ):
116
+ a = xp .arange (24 ).reshape (2 , 3 , 4 )
117
+ aoa_a = xp .apply_over_axes (xp .sum , a , [0 , 2 ])
118
+ return aoa_a
119
+
120
+ def test_apply_over_axis_invalid_0darr (self ):
121
+ # cupy will not accept 0darr, but numpy does
122
+ with pytest .raises (AxisError ):
123
+ a = cupy .array (42 )
124
+ cupy .apply_over_axes (cupy .sum , a , 0 )
125
+ # test for numpy, it can run without error
126
+ a = numpy .array (42 )
127
+ numpy .apply_over_axes (numpy .sum , a , 0 )
128
+
129
+ @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
130
+ def test_apply_over_axis_shape_preserve_func (self , xp ):
131
+ a = xp .arange (10 ).reshape (2 , 5 , 1 )
132
+
133
+ def normalize (arr , axis ):
134
+ """shape-preserve operation, return {x_i/sum(x)}"""
135
+ row_sums = arr .sum (axis = axis )
136
+ return a / row_sums [:, xp .newaxis ]
137
+
138
+ aoa_a = xp .apply_over_axes (normalize , a , 1 )
139
+ assert a .shape == aoa_a .shape
140
+ return aoa_a
141
+
142
+ def test_apply_over_axis_invalid_axis (self ):
143
+ for xp in [numpy , cupy ]:
144
+ a = xp .ones ((8 , 4 ))
145
+ axis = 3
146
+ with pytest .raises (AxisError ):
147
+ xp .apply_over_axes (xp .sum , a , axis )
148
+
149
+
112
150
class TestPutAlongAxis (unittest .TestCase ):
113
151
114
152
@testing .for_all_dtypes ()
0 commit comments