@@ -109,6 +109,43 @@ def test_apply_along_axis_invalid_axis():
109
109
xp .apply_along_axis (xp .sum , axis , a )
110
110
111
111
112
+ class TestApplyOverAxes (unittest .TestCase ):
113
+ @testing .numpy_cupy_array_equal (type_check = has_support_aspect64 ())
114
+ def test_simple (self , xp ):
115
+ a = xp .arange (24 ).reshape (2 , 3 , 4 )
116
+ aoa_a = xp .apply_over_axes (xp .sum , a , [0 , 2 ])
117
+ return aoa_a
118
+
119
+ def test_apply_over_axis_invalid_0darr (self ):
120
+ # cupy will not accept 0darr, but numpy does
121
+ with pytest .raises (AxisError ):
122
+ a = cupy .array (42 )
123
+ cupy .apply_over_axes (cupy .sum , a , 0 )
124
+ # test for numpy, it can run without error
125
+ a = numpy .array (42 )
126
+ numpy .apply_over_axes (numpy .sum , a , 0 )
127
+
128
+ @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
129
+ def test_apply_over_axis_shape_preserve_func (self , xp ):
130
+ a = xp .arange (10 ).reshape (2 , 5 , 1 )
131
+
132
+ def normalize (arr , axis ):
133
+ """shape-preserve operation, return {x_i/sum(x)}"""
134
+ row_sums = arr .sum (axis = axis )
135
+ return a / row_sums [:, xp .newaxis ]
136
+
137
+ aoa_a = xp .apply_over_axes (normalize , a , 1 )
138
+ assert a .shape == aoa_a .shape
139
+ return aoa_a
140
+
141
+ def test_apply_over_axis_invalid_axis (self ):
142
+ for xp in [numpy , cupy ]:
143
+ a = xp .ones ((8 , 4 ))
144
+ axis = 3
145
+ with pytest .raises (AxisError ):
146
+ xp .apply_over_axes (xp .sum , a , axis )
147
+
148
+
112
149
class TestPutAlongAxis (unittest .TestCase ):
113
150
114
151
@testing .for_all_dtypes ()
0 commit comments