55
55
56
56
import dpnp
57
57
import dpnp .backend .extensions .vm ._vm_impl as vmi
58
- from dpnp .backend .extensions .sycl_ext import _sycl_ext_impl
59
- from dpnp .dpnp_array import dpnp_array
60
- from dpnp .dpnp_utils import call_origin , get_usm_allocations
61
58
59
+ from .backend .extensions .sycl_ext import _sycl_ext_impl
62
60
from .dpnp_algo import (
63
61
dpnp_cumprod ,
64
62
dpnp_ediff1d ,
81
79
acceptance_fn_sign ,
82
80
acceptance_fn_subtract ,
83
81
)
82
+ from .dpnp_array import dpnp_array
83
+ from .dpnp_utils import call_origin , get_usm_allocations
84
84
from .dpnp_utils .dpnp_utils_linearalgebra import dpnp_cross
85
+ from .dpnp_utils .dpnp_utils_reduction import dpnp_wrap_reduction_call
85
86
86
87
__all__ = [
87
88
"abs" ,
@@ -158,33 +159,14 @@ def _append_to_diff_array(a, axis, combined, values):
158
159
combined .append (values )
159
160
160
161
161
- def _wrap_reduction_call (a , dtype , out , _reduction_fn , * args , ** kwargs ):
162
- """Wrap a call of reduction functions from dpctl.tensor interface ."""
162
+ def _get_reduction_res_dt (a , dtype , _out ):
163
+ """Get a data type used by dpctl for result array in reduction function ."""
163
164
164
- input_out = out
165
- if out is None :
166
- usm_out = None
167
- else :
168
- dpnp .check_supported_arrays_type (out )
169
-
170
- # get a data type used by dpctl for result array in reduction function
171
- if dtype is None :
172
- res_dt = dtu ._default_accumulation_dtype (a .dtype , a .sycl_queue )
173
- else :
174
- res_dt = dpnp .dtype (dtype )
175
- res_dt = dtu ._to_device_supported_dtype (res_dt , a .sycl_device )
165
+ if dtype is None :
166
+ return dtu ._default_accumulation_dtype (a .dtype , a .sycl_queue )
176
167
177
- # dpctl requires strict data type matching of out array with the result
178
- if out .dtype != res_dt :
179
- out = dpnp .astype (out , dtype = res_dt , copy = False )
180
-
181
- usm_out = dpnp .get_usm_ndarray (out )
182
-
183
- kwargs ["dtype" ] = dtype
184
- kwargs ["out" ] = usm_out
185
- res_usm = _reduction_fn (* args , ** kwargs )
186
- res = dpnp_array ._create_from_usm_ndarray (res_usm )
187
- return dpnp .get_result_array (res , input_out , casting = "unsafe" )
168
+ dtype = dpnp .dtype (dtype )
169
+ return dtu ._to_device_supported_dtype (dtype , a .sycl_device )
188
170
189
171
190
172
_ABS_DOCSTRING = """
@@ -868,19 +850,22 @@ def cumsum(a, axis=None, dtype=None, out=None):
868
850
----------
869
851
a : {dpnp.ndarray, usm_ndarray}
870
852
Input array.
871
- axis : int, optional
872
- Axis along which the cumulative sum is computed. The default (``None``)
873
- is to compute the cumulative sum over the flattened array.
853
+ axis : {int}, optional
854
+ Axis along which the cumulative sum is computed. It defaults to compute
855
+ the cumulative sum over the flattened array.
856
+ Default: ``None``.
874
857
dtype : {None, dtype}, optional
875
858
Type of the returned array and of the accumulator in which the elements
876
859
are summed. If `dtype` is not specified, it defaults to the dtype of
877
860
`a`, unless `a` has an integer dtype with a precision less than that of
878
861
the default platform integer. In that case, the default platform
879
862
integer is used.
880
- out : {dpnp.ndarray, usm_ndarray}, optional
863
+ Default: ``None``.
864
+ out : {None, dpnp.ndarray, usm_ndarray}, optional
881
865
Alternative output array in which to place the result. It must have the
882
866
same shape and buffer length as the expected output but the type will
883
867
be cast if necessary.
868
+ Default: ``None``.
884
869
885
870
Returns
886
871
-------
@@ -930,8 +915,14 @@ def cumsum(a, axis=None, dtype=None, out=None):
930
915
else :
931
916
usm_a = dpnp .get_usm_ndarray (a )
932
917
933
- return _wrap_reduction_call (
934
- a , dtype , out , dpt .cumulative_sum , usm_a , axis = axis
918
+ return dpnp_wrap_reduction_call (
919
+ a ,
920
+ out ,
921
+ dpt .cumulative_sum ,
922
+ _get_reduction_res_dt ,
923
+ usm_a ,
924
+ axis = axis ,
925
+ dtype = dtype ,
935
926
)
936
927
937
928
@@ -945,13 +936,13 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):
945
936
----------
946
937
a : {dpnp.ndarray, usm_ndarray}
947
938
Input array
948
- n : int, optional
939
+ n : { int} , optional
949
940
The number of times the values differ. If ``zero``, the input
950
941
is returned as-is.
951
- axis : int, optional
942
+ axis : { int} , optional
952
943
The axis along which the difference is taken, default is the
953
944
last axis.
954
- prepend, append : {scalar, dpnp.ndarray, usm_ndarray}, optional
945
+ prepend, append : {None, scalar, dpnp.ndarray, usm_ndarray}, optional
955
946
Values to prepend or append to `a` along axis prior to
956
947
performing the difference. Scalar values are expanded to
957
948
arrays with length 1 in the direction of axis and the shape
@@ -2332,8 +2323,15 @@ def prod(
2332
2323
dpnp .check_limitations (initial = initial , where = where )
2333
2324
usm_a = dpnp .get_usm_ndarray (a )
2334
2325
2335
- return _wrap_reduction_call (
2336
- a , dtype , out , dpt .prod , usm_a , axis = axis , keepdims = keepdims
2326
+ return dpnp_wrap_reduction_call (
2327
+ a ,
2328
+ out ,
2329
+ dpt .prod ,
2330
+ _get_reduction_res_dt ,
2331
+ usm_a ,
2332
+ axis = axis ,
2333
+ dtype = dtype ,
2334
+ keepdims = keepdims ,
2337
2335
)
2338
2336
2339
2337
@@ -2912,8 +2910,15 @@ def sum(
2912
2910
return result
2913
2911
2914
2912
usm_a = dpnp .get_usm_ndarray (a )
2915
- return _wrap_reduction_call (
2916
- a , dtype , out , dpt .sum , usm_a , axis = axis , keepdims = keepdims
2913
+ return dpnp_wrap_reduction_call (
2914
+ a ,
2915
+ out ,
2916
+ dpt .sum ,
2917
+ _get_reduction_res_dt ,
2918
+ usm_a ,
2919
+ axis = axis ,
2920
+ dtype = dtype ,
2921
+ keepdims = keepdims ,
2917
2922
)
2918
2923
2919
2924
0 commit comments