Skip to content

Commit 1f6df8e

Browse files
committed
Add third party tests for dpnp.apply_over_axes
1 parent 76f4360 commit 1f6df8e

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

dpnp/tests/third_party/cupy/lib_tests/test_shape_base.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,43 @@ def test_apply_along_axis_invalid_axis():
109109
xp.apply_along_axis(xp.sum, axis, a)
110110

111111

112+
class TestApplyOverAxes(unittest.TestCase):
113+
@testing.numpy_cupy_array_equal(type_check=has_support_aspect64())
114+
def test_simple(self, xp):
115+
a = xp.arange(24).reshape(2, 3, 4)
116+
aoa_a = xp.apply_over_axes(xp.sum, a, [0, 2])
117+
return aoa_a
118+
119+
def test_apply_over_axis_invalid_0darr(self):
120+
# cupy will not accept 0darr, but numpy does
121+
with pytest.raises(AxisError):
122+
a = cupy.array(42)
123+
cupy.apply_over_axes(cupy.sum, a, 0)
124+
# test for numpy, it can run without error
125+
a = numpy.array(42)
126+
numpy.apply_over_axes(numpy.sum, a, 0)
127+
128+
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
129+
def test_apply_over_axis_shape_preserve_func(self, xp):
130+
a = xp.arange(10).reshape(2, 5, 1)
131+
132+
def normalize(arr, axis):
133+
"""shape-preserve operation, return {x_i/sum(x)}"""
134+
row_sums = arr.sum(axis=axis)
135+
return a / row_sums[:, xp.newaxis]
136+
137+
aoa_a = xp.apply_over_axes(normalize, a, 1)
138+
assert a.shape == aoa_a.shape
139+
return aoa_a
140+
141+
def test_apply_over_axis_invalid_axis(self):
142+
for xp in [numpy, cupy]:
143+
a = xp.ones((8, 4))
144+
axis = 3
145+
with pytest.raises(AxisError):
146+
xp.apply_over_axes(xp.sum, a, axis)
147+
148+
112149
class TestPutAlongAxis(unittest.TestCase):
113150

114151
@testing.for_all_dtypes()

0 commit comments

Comments
 (0)