Skip to content

Commit 24df081

Browse files
Review comments, bug fixing and tests
1 parent fdec1ee commit 24df081

File tree

4 files changed

+263
-9
lines changed

4 files changed

+263
-9
lines changed

dpnp/backend/extensions/sycl_ext/sum_mean.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,28 @@ sycl::event mean_over_axis(usm_ndarray input,
7878
return mean_fn(input, output, depends);
7979
}
8080

81-
// Looks up the sum-over-axis implementation for the given input/output pair.
//
// Returns py::none() when check_limitations() rejects the pair or when the
// dispatcher yields a null entry; otherwise returns the dispatched callable
// wrapped in a py::cpp_function.
py::object get_sum_over_axis(usm_ndarray input, usm_ndarray output)
{
    if (!sycl_ext::check_limitations(input, output)) {
        return py::none();
    }

    auto sum_impl = (*sum_dispatcher)({input, output});
    if (sum_impl == nullptr) {
        // No entry for this input/output combination.
        return py::none();
    }

    return py::cpp_function(sum_impl);
}
8892

89-
// Looks up the mean-over-axis implementation for the given input/output pair.
//
// Returns py::none() when check_limitations() rejects the pair or when the
// dispatcher yields a null entry; otherwise returns the dispatched callable
// wrapped in a py::cpp_function.
py::object get_mean_over_axis(usm_ndarray input, usm_ndarray output)
{
    if (!sycl_ext::check_limitations(input, output)) {
        return py::none();
    }

    auto mean_impl = (*mean_dispatcher)({input, output});
    if (mean_impl == nullptr) {
        // No entry for this input/output combination.
        return py::none();
    }

    return py::cpp_function(mean_impl);
}
96104

97105
PYBIND11_MODULE(_sycl_ext_impl, m)

dpnp/backend/extensions/sycl_ext/sum_mean.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ struct sum_mean
171171
throw py::value_error(
172172
"Input array axis 1 size must match output array size");
173173

174+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
175+
if (overlap(in, out))
176+
throw py::value_error("Input and output array are overlapped");
177+
174178
check_limitations(in, out, true);
175179
}
176180
};
@@ -209,7 +213,7 @@ using MeanOverAxisContigDispatcher =
209213
dpctl::tensor::usm_ndarray,
210214
UsmArrayMatcher,
211215
SumInputTypes,
212-
SumOutputTypes>;
216+
MeanOutputTypes>;
213217
} // namespace sycl_ext
214218
} // namespace ext
215219
} // namespace backend

dpnp/dpnp_iface_mathematical.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,15 +1828,19 @@ def sum(
18281828
elif where is not True:
18291829
pass
18301830
else:
1831-
if axis == (0,) and len(x.shape) == 2:
1831+
if axis == (0,) and len(x.shape) == 2 and not keepdims:
18321832
from dpctl.tensor._reduction import _default_reduction_dtype
18331833

18341834
from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl
18351835

18361836
input = dpnp.get_usm_ndarray(x)
18371837

18381838
queue = input.sycl_queue
1839-
out_dtype = _default_reduction_dtype(input.dtype, queue)
1839+
out_dtype = (
1840+
_default_reduction_dtype(input.dtype, queue)
1841+
if dtype is None
1842+
else dtype
1843+
)
18401844
output = dpt.empty(
18411845
input.shape[1], dtype=out_dtype, sycl_queue=queue
18421846
)

tests/test_extensions.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
from itertools import product
2+
3+
import dpctl
4+
import dpctl.tensor as dpt
5+
import pytest
6+
7+
import dpnp
8+
from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl
9+
10+
# Every CPU or GPU device reported by dpctl; tests are parametrized over these.
all_devices = [dev for dev in dpctl.get_devices() if dev.is_cpu or dev.is_gpu]

# Input dtypes used by the "supported types" tests.
sum_supported_input_dtypes = [
    *(
        dpt.dtype(code)
        for code in ("i1", "u1", "i2", "u2", "i4", "u4", "i8", "u8")
    ),
    dpt.float32,
    dpt.float64,
]

# Output dtypes used by the "supported types" tests for sum.
sum_supported_output_dtypes = [
    *(dpt.dtype(code) for code in ("i4", "u4", "i8", "u8")),
    dpt.float32,
    dpt.float64,
]

# Output dtypes used by the "supported types" tests for mean.
mean_supported_output_dtypes = [dpt.float32, dpt.float64]

# Output dtypes listed for sum but not for mean.
sum_without_mean_supported_dtypes = list(
    set(sum_supported_output_dtypes) - set(mean_supported_output_dtypes)
)

# Input dtypes for which both getters are expected to return None.
sum_unsupported_input_dtypes = [
    dpt.bool,
    dpt.float16,
    dpt.complex64,
    dpt.complex128,
]

# Output dtypes for which the sum getter is expected to return None.
sum_unsupported_output_dtypes = [
    dpt.bool,
    *(dpt.dtype(code) for code in ("i1", "u1", "i2", "u2")),
    dpt.float16,
    dpt.complex64,
    dpt.complex128,
]

# Output dtypes for which the mean getter is expected to return None.
mean_unsupported_output_dtypes = [
    dpt.bool,
    *(
        dpt.dtype(code)
        for code in ("i1", "u1", "i2", "u2", "i4", "u4", "i8", "u8")
    ),
    dpt.float16,
    dpt.complex64,
    dpt.complex128,
]


# Getter functions under test, grouped for use in parametrization.
sum_only = [_sycl_ext_impl._get_sum_over_axis_0]
mean_only = [_sycl_ext_impl._get_mean_over_axis_0]
mean_sum = [*sum_only, *mean_only]
75+
76+
77+
def supported_by_device(device, typ):
    """Return whether ``device`` reports the aspect needed for ``typ``.

    ``float64``/``complex128`` map to ``device.has_aspect_fp64`` and
    ``float16`` maps to ``device.has_aspect_fp16``; any other type yields
    ``True``.
    """
    if typ in (dpt.float64, dpt.complex128):
        return device.has_aspect_fp64
    if typ == dpt.float16:
        return device.has_aspect_fp16
    return True
85+
86+
87+
def skip_unsupported(device, typ):
    """Skip the running test when ``supported_by_device`` is false."""
    if supported_by_device(device, typ):
        return
    pytest.skip(f"{typ} type is not supported by {device}")
90+
91+
92+
@pytest.mark.parametrize(
    "func, device, input_type, output_type",
    product(
        mean_sum,
        all_devices,
        sum_supported_input_dtypes,
        mean_supported_output_dtypes,
    ),
)
def test_mean_sum_over_axis_0_supported_types(
    func, device, input_type, output_type
):
    """Both getters must return a non-None result for these dtype pairs."""
    for dtype in (input_type, output_type):
        skip_unsupported(device, dtype)

    rows, cols = 20, 10
    src = dpt.empty((rows, cols), dtype=input_type, device=device)
    dst = dpt.empty(cols, dtype=output_type, device=device)

    impl = func(src, dst)
    assert impl is not None
114+
115+
116+
@pytest.mark.parametrize(
    "func, device, input_type, output_type",
    product(
        sum_only,
        all_devices,
        sum_supported_input_dtypes,
        sum_without_mean_supported_dtypes,
    ),
)
def test_sum_over_axis_0_supported_types(func, device, input_type, output_type):
    """The sum getter must return a non-None result for sum-only outputs."""
    for dtype in (input_type, output_type):
        skip_unsupported(device, dtype)

    rows, cols = 20, 10
    src = dpt.empty((rows, cols), dtype=input_type, device=device)
    dst = dpt.empty(cols, dtype=output_type, device=device)

    impl = func(src, dst)
    assert impl is not None
136+
137+
138+
@pytest.mark.parametrize(
    "func, device, input_type, output_type",
    product(mean_sum, all_devices, sum_unsupported_input_dtypes, [dpt.float32]),
)
def test_mean_sum_over_axis_0_unsupported_in_types(
    func, device, input_type, output_type
):
    """Both getters must return None for the unsupported input dtypes."""
    for dtype in (input_type, output_type):
        skip_unsupported(device, dtype)

    rows, cols = 1, 1
    src = dpt.empty((rows, cols), dtype=input_type, device=device)
    dst = dpt.empty(cols, dtype=output_type, device=device)

    impl = func(src, dst)
    assert impl is None
155+
156+
157+
@pytest.mark.parametrize(
    "func, device, input_type, output_type",
    product(
        sum_only, all_devices, [dpt.float32], sum_unsupported_output_dtypes
    ),
)
def test_sum_over_axis_0_unsupported_out_types(
    func, device, input_type, output_type
):
    """The sum getter must return None for the unsupported output dtypes."""
    for dtype in (input_type, output_type):
        skip_unsupported(device, dtype)

    rows, cols = 1, 1
    src = dpt.empty((rows, cols), dtype=input_type, device=device)
    dst = dpt.empty(cols, dtype=output_type, device=device)

    impl = func(src, dst)
    assert impl is None
176+
177+
178+
@pytest.mark.parametrize(
    "func, device, input_type, output_type",
    product(
        mean_only, all_devices, [dpt.float32], mean_unsupported_output_dtypes
    ),
)
def test_mean_over_axis_0_unsupported_out_types(
    func, device, input_type, output_type
):
    """The mean getter must return None for the unsupported output dtypes."""
    skip_unsupported(device, input_type)
    skip_unsupported(device, output_type)

    height = 1
    width = 1

    input = dpt.empty((height, width), dtype=input_type, device=device)
    output = dpt.empty(width, dtype=output_type, device=device)

    # The failing dtype is already part of the parametrized test id, so no
    # extra diagnostic call/print is needed before the assertion.
    assert func(input, output) is None
199+
200+
201+
@pytest.mark.parametrize(
    "func, device, input_type, output_type",
    product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
)
def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
    """Both getters must return None when the input is a transposed array."""
    skip_unsupported(device, input_type)
    skip_unsupported(device, output_type)

    height = 20
    width = 10

    # NOTE(review): after ``.T`` the input has shape (width, height), so its
    # axis-1 length (20) differs from the output length (10). Confirm that the
    # expected None is caused by the memory layout and not by this mismatch.
    input = dpt.empty((height, width), dtype=input_type, device=device).T
    output = dpt.empty(width, dtype=output_type, device=device)

    assert func(input, output) is None
218+
219+
220+
@pytest.mark.parametrize(
    "func, device, input_type, output_type",
    product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
)
def test_mean_over_axis_0_f_contig_output(
    func, device, input_type, output_type
):
    """Both getters must return None when the output is a step-2 slice."""
    skip_unsupported(device, input_type)
    skip_unsupported(device, output_type)

    height = 1
    width = 10

    input = dpt.empty((height, width), dtype=input_type, device=device)
    # Every second element of a buffer twice as long: ``width`` elements,
    # taken with a step of 2.
    output = dpt.empty(2 * width, dtype=output_type, device=device)[::2]

    assert func(input, output) is None

0 commit comments

Comments
 (0)