Skip to content

Commit 07babc0

Browse files
committed
Separate wrapping function to reduction utils
1 parent 5b2494d commit 07babc0

File tree

6 files changed

+180
-134
lines changed

6 files changed

+180
-134
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,8 @@
5555

5656
import dpnp
5757
import dpnp.backend.extensions.vm._vm_impl as vmi
58-
from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl
59-
from dpnp.dpnp_array import dpnp_array
60-
from dpnp.dpnp_utils import call_origin, get_usm_allocations
6158

59+
from .backend.extensions.sycl_ext import _sycl_ext_impl
6260
from .dpnp_algo import (
6361
dpnp_cumprod,
6462
dpnp_ediff1d,
@@ -81,7 +79,10 @@
8179
acceptance_fn_sign,
8280
acceptance_fn_subtract,
8381
)
82+
from .dpnp_array import dpnp_array
83+
from .dpnp_utils import call_origin, get_usm_allocations
8484
from .dpnp_utils.dpnp_utils_linearalgebra import dpnp_cross
85+
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
8586

8687
__all__ = [
8788
"abs",
@@ -158,33 +159,14 @@ def _append_to_diff_array(a, axis, combined, values):
158159
combined.append(values)
159160

160161

161-
def _wrap_reduction_call(a, dtype, out, _reduction_fn, *args, **kwargs):
162-
"""Wrap a call of reduction functions from dpctl.tensor interface."""
162+
def _get_reduction_res_dt(a, dtype, _out):
163+
"""Get a data type used by dpctl for result array in reduction function."""
163164

164-
input_out = out
165-
if out is None:
166-
usm_out = None
167-
else:
168-
dpnp.check_supported_arrays_type(out)
169-
170-
# get a data type used by dpctl for result array in reduction function
171-
if dtype is None:
172-
res_dt = dtu._default_accumulation_dtype(a.dtype, a.sycl_queue)
173-
else:
174-
res_dt = dpnp.dtype(dtype)
175-
res_dt = dtu._to_device_supported_dtype(res_dt, a.sycl_device)
165+
if dtype is None:
166+
return dtu._default_accumulation_dtype(a.dtype, a.sycl_queue)
176167

177-
# dpctl requires strict data type matching of out array with the result
178-
if out.dtype != res_dt:
179-
out = dpnp.astype(out, dtype=res_dt, copy=False)
180-
181-
usm_out = dpnp.get_usm_ndarray(out)
182-
183-
kwargs["dtype"] = dtype
184-
kwargs["out"] = usm_out
185-
res_usm = _reduction_fn(*args, **kwargs)
186-
res = dpnp_array._create_from_usm_ndarray(res_usm)
187-
return dpnp.get_result_array(res, input_out, casting="unsafe")
168+
dtype = dpnp.dtype(dtype)
169+
return dtu._to_device_supported_dtype(dtype, a.sycl_device)
188170

189171

190172
_ABS_DOCSTRING = """
@@ -868,19 +850,22 @@ def cumsum(a, axis=None, dtype=None, out=None):
868850
----------
869851
a : {dpnp.ndarray, usm_ndarray}
870852
Input array.
871-
axis : int, optional
872-
Axis along which the cumulative sum is computed. The default (``None``)
873-
is to compute the cumulative sum over the flattened array.
853+
axis : {int}, optional
854+
Axis along which the cumulative sum is computed. It defaults to compute
855+
the cumulative sum over the flattened array.
856+
Default: ``None``.
874857
dtype : {None, dtype}, optional
875858
Type of the returned array and of the accumulator in which the elements
876859
are summed. If `dtype` is not specified, it defaults to the dtype of
877860
`a`, unless `a` has an integer dtype with a precision less than that of
878861
the default platform integer. In that case, the default platform
879862
integer is used.
880-
out : {dpnp.ndarray, usm_ndarray}, optional
863+
Default: ``None``.
864+
out : {None, dpnp.ndarray, usm_ndarray}, optional
881865
Alternative output array in which to place the result. It must have the
882866
same shape and buffer length as the expected output but the type will
883867
be cast if necessary.
868+
Default: ``None``.
884869
885870
Returns
886871
-------
@@ -930,8 +915,14 @@ def cumsum(a, axis=None, dtype=None, out=None):
930915
else:
931916
usm_a = dpnp.get_usm_ndarray(a)
932917

933-
return _wrap_reduction_call(
934-
a, dtype, out, dpt.cumulative_sum, usm_a, axis=axis
918+
return dpnp_wrap_reduction_call(
919+
a,
920+
out,
921+
dpt.cumulative_sum,
922+
_get_reduction_res_dt,
923+
usm_a,
924+
axis=axis,
925+
dtype=dtype,
935926
)
936927

937928

@@ -945,13 +936,13 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):
945936
----------
946937
a : {dpnp.ndarray, usm_ndarray}
947938
Input array
948-
n : int, optional
939+
n : {int}, optional
949940
The number of times the values differ. If ``zero``, the input
950941
is returned as-is.
951-
axis : int, optional
942+
axis : {int}, optional
952943
The axis along which the difference is taken, default is the
953944
last axis.
954-
prepend, append : {scalar, dpnp.ndarray, usm_ndarray}, optional
945+
prepend, append : {None, scalar, dpnp.ndarray, usm_ndarray}, optional
955946
Values to prepend or append to `a` along axis prior to
956947
performing the difference. Scalar values are expanded to
957948
arrays with length 1 in the direction of axis and the shape
@@ -2332,8 +2323,15 @@ def prod(
23322323
dpnp.check_limitations(initial=initial, where=where)
23332324
usm_a = dpnp.get_usm_ndarray(a)
23342325

2335-
return _wrap_reduction_call(
2336-
a, dtype, out, dpt.prod, usm_a, axis=axis, keepdims=keepdims
2326+
return dpnp_wrap_reduction_call(
2327+
a,
2328+
out,
2329+
dpt.prod,
2330+
_get_reduction_res_dt,
2331+
usm_a,
2332+
axis=axis,
2333+
dtype=dtype,
2334+
keepdims=keepdims,
23372335
)
23382336

23392337

@@ -2912,8 +2910,15 @@ def sum(
29122910
return result
29132911

29142912
usm_a = dpnp.get_usm_ndarray(a)
2915-
return _wrap_reduction_call(
2916-
a, dtype, out, dpt.sum, usm_a, axis=axis, keepdims=keepdims
2913+
return dpnp_wrap_reduction_call(
2914+
a,
2915+
out,
2916+
dpt.sum,
2917+
_get_reduction_res_dt,
2918+
usm_a,
2919+
axis=axis,
2920+
dtype=dtype,
2921+
keepdims=keepdims,
29172922
)
29182923

29192924

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
301301
----------
302302
a : {dpnp.ndarray, usm_ndarray}
303303
Input array.
304-
axis : int, optional
304+
axis : {int}, optional
305305
Axis along which the cumulative sum is computed. The default (``None``)
306306
is to compute the cumulative sum over the flattened array.
307307
dtype : {None, dtype}, optional
@@ -310,7 +310,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
310310
`a`, unless `a` has an integer dtype with a precision less than that of
311311
the default platform integer. In that case, the default platform
312312
integer is used.
313-
out : {dpnp.ndarray, usm_ndarray}, optional
313+
out : {None, dpnp.ndarray, usm_ndarray}, optional
314314
Alternative output array in which to place the result. It must have the
315315
same shape and buffer length as the expected output but the type will
316316
be cast if necessary.

dpnp/dpnp_iface_searching.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -49,39 +49,24 @@
4949
from .dpnp_utils import (
5050
get_usm_allocations,
5151
)
52+
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
5253

5354
__all__ = ["argmax", "argmin", "searchsorted", "where"]
5455

5556

56-
def _wrap_search_call(a, out, _search_fn, *args, **kwargs):
57-
"""Wrap a call of search functions from dpctl.tensor interface."""
57+
def _get_search_res_dt(a, _dtype, out):
58+
"""Get a data type used by dpctl for result array in search function."""
5859

59-
input_out = out
60-
if out is None:
61-
usm_out = None
62-
else:
63-
dpnp.check_supported_arrays_type(out)
64-
65-
# get a data type used by dpctl for result array in search function
66-
res_dt = dti.default_device_index_type(a.sycl_device)
67-
68-
# numpy raises TypeError if "out" data type mismatch default index type
69-
if not dpnp.can_cast(out.dtype, res_dt, casting="safe"):
70-
raise TypeError(
71-
f"Cannot cast from {out.dtype} to {res_dt} "
72-
"according to the rule safe."
73-
)
74-
75-
# dpctl requires strict data type matching of out array with the result
76-
if out.dtype != res_dt:
77-
out = dpnp.astype(out, dtype=res_dt, copy=False)
78-
79-
usm_out = dpnp.get_usm_ndarray(out)
60+
# get a data type used by dpctl for result array in search function
61+
res_dt = dti.default_device_index_type(a.sycl_device)
8062

81-
kwargs["out"] = usm_out
82-
res_usm = _search_fn(*args, **kwargs)
83-
res = dpnp_array._create_from_usm_ndarray(res_usm)
84-
return dpnp.get_result_array(res, input_out, casting="unsafe")
63+
# numpy raises TypeError if "out" data type mismatch default index type
64+
if not dpnp.can_cast(out.dtype, res_dt, casting="safe"):
65+
raise TypeError(
66+
f"Cannot cast from {out.dtype} to {res_dt} "
67+
"according to the rule safe."
68+
)
69+
return res_dt
8570

8671

8772
def argmax(a, axis=None, out=None, *, keepdims=False):
@@ -163,8 +148,14 @@ def argmax(a, axis=None, out=None, *, keepdims=False):
163148
"""
164149

165150
usm_a = dpnp.get_usm_ndarray(a)
166-
return _wrap_search_call(
167-
a, out, dpt.argmax, usm_a, axis=axis, keepdims=keepdims
151+
return dpnp_wrap_reduction_call(
152+
a,
153+
out,
154+
dpt.argmax,
155+
_get_search_res_dt,
156+
usm_a,
157+
axis=axis,
158+
keepdims=keepdims,
168159
)
169160

170161

@@ -248,8 +239,14 @@ def argmin(a, axis=None, out=None, *, keepdims=False):
248239
"""
249240

250241
usm_a = dpnp.get_usm_ndarray(a)
251-
return _wrap_search_call(
252-
a, out, dpt.argmin, usm_a, axis=axis, keepdims=keepdims
242+
return dpnp_wrap_reduction_call(
243+
a,
244+
out,
245+
dpt.argmin,
246+
_get_search_res_dt,
247+
usm_a,
248+
axis=axis,
249+
keepdims=keepdims,
253250
)
254251

255252

dpnp/dpnp_iface_statistics.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
call_origin,
5555
get_usm_allocations,
5656
)
57+
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
5758
from .dpnp_utils.dpnp_utils_statistics import (
5859
dpnp_cov,
5960
)
@@ -118,28 +119,10 @@ def _count_reduce_items(arr, axis, where=True):
118119
return items
119120

120121

121-
def _wrap_comparison_call(a, out, _comparison_fn, *args, **kwargs):
122-
"""Wrap a call of comparison functions from dpctl.tensor interface."""
122+
def _get_comparison_res_dt(a, _dtype, _out):
123+
"""Get a data type used by dpctl for result array in comparison function."""
123124

124-
input_out = out
125-
if out is None:
126-
usm_out = None
127-
else:
128-
dpnp.check_supported_arrays_type(out)
129-
130-
# get dtype used by dpctl for result array in comparison function
131-
res_dt = a.dtype
132-
133-
# dpctl requires strict data type matching of out array with the result
134-
if out.dtype != res_dt:
135-
out = dpnp.astype(out, dtype=res_dt, copy=False)
136-
137-
usm_out = dpnp.get_usm_ndarray(out)
138-
139-
kwargs["out"] = usm_out
140-
res_usm = _comparison_fn(*args, **kwargs)
141-
res = dpnp_array._create_from_usm_ndarray(res_usm)
142-
return dpnp.get_result_array(res, input_out, casting="unsafe")
125+
return a.dtype
143126

144127

145128
def amax(a, axis=None, out=None, keepdims=False, initial=None, where=True):
@@ -548,8 +531,14 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
548531
dpnp.check_limitations(initial=initial, where=where)
549532
usm_a = dpnp.get_usm_ndarray(a)
550533

551-
return _wrap_comparison_call(
552-
a, out, dpt.max, usm_a, axis=axis, keepdims=keepdims
534+
return dpnp_wrap_reduction_call(
535+
a,
536+
out,
537+
dpt.max,
538+
_get_comparison_res_dt,
539+
usm_a,
540+
axis=axis,
541+
keepdims=keepdims,
553542
)
554543

555544

@@ -757,8 +746,14 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
757746
dpnp.check_limitations(initial=initial, where=where)
758747
usm_a = dpnp.get_usm_ndarray(a)
759748

760-
return _wrap_comparison_call(
761-
a, out, dpt.min, usm_a, axis=axis, keepdims=keepdims
749+
return dpnp_wrap_reduction_call(
750+
a,
751+
out,
752+
dpt.min,
753+
_get_comparison_res_dt,
754+
usm_a,
755+
axis=axis,
756+
keepdims=keepdims,
762757
)
763758

764759

0 commit comments

Comments
 (0)