Skip to content

Commit 17a3251

Browse files
authored
Add third party tests for dpnp.apply_over_axes (#2401)
This PR add third party tests for `dpnp.apply_over_axes`.
1 parent f74b5a8 commit 17a3251

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

dpnp/tests/third_party/cupy/lib_tests/test_shape_base.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,44 @@ def test_apply_along_axis_invalid_axis():
109109
xp.apply_along_axis(xp.sum, axis, a)
110110

111111

112+
class TestApplyOverAxes(unittest.TestCase):
113+
114+
@testing.numpy_cupy_array_equal(type_check=has_support_aspect64())
115+
def test_simple(self, xp):
116+
a = xp.arange(24).reshape(2, 3, 4)
117+
aoa_a = xp.apply_over_axes(xp.sum, a, [0, 2])
118+
return aoa_a
119+
120+
def test_apply_over_axis_invalid_0darr(self):
121+
# cupy will not accept 0darr, but numpy does
122+
with pytest.raises(AxisError):
123+
a = cupy.array(42)
124+
cupy.apply_over_axes(cupy.sum, a, 0)
125+
# test for numpy, it can run without error
126+
a = numpy.array(42)
127+
numpy.apply_over_axes(numpy.sum, a, 0)
128+
129+
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
130+
def test_apply_over_axis_shape_preserve_func(self, xp):
131+
a = xp.arange(10).reshape(2, 5, 1)
132+
133+
def normalize(arr, axis):
134+
"""shape-preserve operation, return {x_i/sum(x)}"""
135+
row_sums = arr.sum(axis=axis)
136+
return a / row_sums[:, xp.newaxis]
137+
138+
aoa_a = xp.apply_over_axes(normalize, a, 1)
139+
assert a.shape == aoa_a.shape
140+
return aoa_a
141+
142+
def test_apply_over_axis_invalid_axis(self):
143+
for xp in [numpy, cupy]:
144+
a = xp.ones((8, 4))
145+
axis = 3
146+
with pytest.raises(AxisError):
147+
xp.apply_over_axes(xp.sum, a, axis)
148+
149+
112150
class TestPutAlongAxis(unittest.TestCase):
113151

114152
@testing.for_all_dtypes()

0 commit comments

Comments
 (0)