Skip to content

Commit 23c79c0

Browse files
Add more tests for dpnp.sum and sum_over_axis_0 extension
1 parent af1af29 commit 23c79c0

File tree

2 files changed

+58
-19
lines changed

2 files changed

+58
-19
lines changed

tests/test_extensions.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,16 @@ def test_mean_over_axis_0_unsupported_out_types(
193193
input = dpt.empty((height, width), dtype=input_type, device=device)
194194
output = dpt.empty(width, dtype=output_type, device=device)
195195

196-
if func(input, output):
197-
print(output_type)
198196
assert func(input, output) is None
199197

200198

201199
@pytest.mark.parametrize(
202200
"func, device, input_type, output_type",
203201
product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
204202
)
205-
def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
203+
def test_mean_sum_over_axis_0_f_contig_input(
204+
func, device, input_type, output_type
205+
):
206206
skip_unsupported(device, input_type)
207207
skip_unsupported(device, output_type)
208208

@@ -212,16 +212,14 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
212212
input = dpt.empty((height, width), dtype=input_type, device=device).T
213213
output = dpt.empty(width, dtype=output_type, device=device)
214214

215-
if func(input, output):
216-
print(output_type)
217215
assert func(input, output) is None
218216

219217

220218
@pytest.mark.parametrize(
221219
"func, device, input_type, output_type",
222220
product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
223221
)
224-
def test_mean_over_axis_0_f_contig_output(
222+
def test_mean_sum_over_axis_0_f_contig_output(
225223
func, device, input_type, output_type
226224
):
227225
skip_unsupported(device, input_type)
@@ -230,9 +228,25 @@ def test_mean_over_axis_0_f_contig_output(
230228
height = 1
231229
width = 10
232230

233-
input = dpt.empty((height, 10), dtype=input_type, device=device)
234-
output = dpt.empty(20, dtype=output_type, device=device)[::2]
231+
input = dpt.empty((height, width), dtype=input_type, device=device)
232+
output = dpt.empty(width * 2, dtype=output_type, device=device)[::2]
233+
234+
assert func(input, output) is None
235+
236+
237+
@pytest.mark.parametrize(
238+
"func, device, input_type, output_type",
239+
product(mean_sum, all_devices, [dpt.float32], [dpt.float32, dpt.float64]),
240+
)
241+
def test_mean_sum_over_axis_0_big_output(func, device, input_type, output_type):
242+
skip_unsupported(device, input_type)
243+
skip_unsupported(device, output_type)
244+
245+
local_mem_size = device.local_mem_size
246+
height = 1
247+
width = 1 + local_mem_size // output_type.itemsize
248+
249+
input = dpt.empty((height, width), dtype=input_type, device=device)
250+
output = dpt.empty(width, dtype=output_type, device=device)
235251

236-
if func(input, output):
237-
print(output_type)
238252
assert func(input, output) is None

tests/test_mathematical.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from itertools import permutations
2+
13
import numpy
24
import pytest
35
from numpy.testing import (
@@ -1042,23 +1044,46 @@ def test_sum_empty_out(dtype):
10421044

10431045

10441046
@pytest.mark.parametrize(
1045-
"shape", [(), (1, 2, 3), (1, 0, 2), (10), (3, 3, 3), (5, 5), (0, 6)]
1047+
"shape",
1048+
[
1049+
(),
1050+
(1, 2, 3),
1051+
(1, 0, 2),
1052+
(10,),
1053+
(3, 3, 3),
1054+
(5, 5),
1055+
(0, 6),
1056+
(10, 1),
1057+
(1, 10),
1058+
],
10461059
)
10471060
@pytest.mark.parametrize(
10481061
"dtype_in", get_all_dtypes(no_complex=True, no_bool=True)
10491062
)
10501063
@pytest.mark.parametrize(
10511064
"dtype_out", get_all_dtypes(no_complex=True, no_bool=True)
10521065
)
1053-
def test_sum(shape, dtype_in, dtype_out):
1054-
a_np = numpy.ones(shape, dtype=dtype_in)
1055-
a = dpnp.ones(shape, dtype=dtype_in)
1056-
axes = [None, 0, 1, 2]
1066+
@pytest.mark.parametrize("transpose", [True, False])
1067+
@pytest.mark.parametrize("keepdims", [False])
1068+
def test_sum(shape, dtype_in, dtype_out, transpose, keepdims):
1069+
size = numpy.prod(shape)
1070+
a_np = numpy.arange(size).astype(dtype_in).reshape(shape)
1071+
a = dpnp.asarray(a_np)
1072+
1073+
if transpose:
1074+
a_np = a_np.T
1075+
a = a.T
1076+
1077+
axes_range = list(numpy.arange(len(shape)))
1078+
axes = [None]
1079+
axes += axes_range
1080+
axes += permutations(axes_range, 2)
1081+
axes.append(tuple(axes_range))
1082+
10571083
for axis in axes:
1058-
if axis is None or axis < a.ndim:
1059-
numpy_res = a_np.sum(axis=axis, dtype=dtype_out)
1060-
dpnp_res = a.sum(axis=axis, dtype=dtype_out)
1061-
assert_array_equal(numpy_res, dpnp_res.asnumpy())
1084+
numpy_res = a_np.sum(axis=axis, dtype=dtype_out, keepdims=keepdims)
1085+
dpnp_res = a.sum(axis=axis, dtype=dtype_out, keepdims=keepdims)
1086+
assert_array_equal(numpy_res, dpnp_res.asnumpy())
10621087

10631088

10641089
class TestMean:

0 commit comments

Comments
 (0)