Skip to content

Commit 904227e

Browse files
authored
simplifying a utility function for reduction operation (#2234)
In this PR, a utility function (`dpnp_wrap_reduction_call`) for reduction operation is simplified.
1 parent 584f0cb commit 904227e

File tree

5 files changed

+32
-57
lines changed

5 files changed

+32
-57
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def _get_max_min(dtype):
152152
return f.max, f.min
153153

154154

155-
def _get_reduction_res_dt(a, dtype, _out):
155+
def _get_reduction_res_dt(a, dtype):
156156
"""Get a data type used by dpctl for result array in reduction function."""
157157

158158
if dtype is None:
@@ -1106,11 +1106,10 @@ def cumprod(a, axis=None, dtype=None, out=None):
11061106
usm_a = dpnp.get_usm_ndarray(a)
11071107

11081108
return dpnp_wrap_reduction_call(
1109-
a,
1109+
usm_a,
11101110
out,
11111111
dpt.cumulative_prod,
1112-
_get_reduction_res_dt,
1113-
usm_a,
1112+
_get_reduction_res_dt(a, dtype),
11141113
axis=axis,
11151114
dtype=dtype,
11161115
)
@@ -1196,11 +1195,10 @@ def cumsum(a, axis=None, dtype=None, out=None):
11961195
usm_a = dpnp.get_usm_ndarray(a)
11971196

11981197
return dpnp_wrap_reduction_call(
1199-
a,
1198+
usm_a,
12001199
out,
12011200
dpt.cumulative_sum,
1202-
_get_reduction_res_dt,
1203-
usm_a,
1201+
_get_reduction_res_dt(a, dtype),
12041202
axis=axis,
12051203
dtype=dtype,
12061204
)
@@ -1281,11 +1279,10 @@ def cumulative_prod(
12811279
"""
12821280

12831281
return dpnp_wrap_reduction_call(
1284-
x,
1282+
dpnp.get_usm_ndarray(x),
12851283
out,
12861284
dpt.cumulative_prod,
1287-
_get_reduction_res_dt,
1288-
dpnp.get_usm_ndarray(x),
1285+
_get_reduction_res_dt(x, dtype),
12891286
axis=axis,
12901287
dtype=dtype,
12911288
include_initial=include_initial,
@@ -1373,11 +1370,10 @@ def cumulative_sum(
13731370
"""
13741371

13751372
return dpnp_wrap_reduction_call(
1376-
x,
1373+
dpnp.get_usm_ndarray(x),
13771374
out,
13781375
dpt.cumulative_sum,
1379-
_get_reduction_res_dt,
1380-
dpnp.get_usm_ndarray(x),
1376+
_get_reduction_res_dt(x, dtype),
13811377
axis=axis,
13821378
dtype=dtype,
13831379
include_initial=include_initial,
@@ -3524,11 +3520,10 @@ def prod(
35243520
usm_a = dpnp.get_usm_ndarray(a)
35253521

35263522
return dpnp_wrap_reduction_call(
3527-
a,
3523+
usm_a,
35283524
out,
35293525
dpt.prod,
3530-
_get_reduction_res_dt,
3531-
usm_a,
3526+
_get_reduction_res_dt(a, dtype),
35323527
axis=axis,
35333528
dtype=dtype,
35343529
keepdims=keepdims,
@@ -4297,11 +4292,10 @@ def sum(
42974292

42984293
usm_a = dpnp.get_usm_ndarray(a)
42994294
return dpnp_wrap_reduction_call(
4300-
a,
4295+
usm_a,
43014296
out,
43024297
dpt.sum,
4303-
_get_reduction_res_dt,
4304-
usm_a,
4298+
_get_reduction_res_dt(a, dtype),
43054299
axis=axis,
43064300
dtype=dtype,
43074301
keepdims=keepdims,

dpnp/dpnp_iface_searching.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@
4848
__all__ = ["argmax", "argmin", "argwhere", "searchsorted", "where"]
4949

5050

51-
def _get_search_res_dt(a, _dtype, out):
51+
def _get_search_res_dt(a, out):
5252
"""Get a data type used by dpctl for result array in search function."""
5353

5454
# get a data type used by dpctl for result array in search function
5555
res_dt = dti.default_device_index_type(a.sycl_device)
5656

5757
# numpy raises TypeError if "out" data type mismatch default index type
58-
if not dpnp.can_cast(out.dtype, res_dt, casting="safe"):
58+
if out is not None and not dpnp.can_cast(out.dtype, res_dt, casting="safe"):
5959
raise TypeError(
6060
f"Cannot cast from {out.dtype} to {res_dt} "
6161
"according to the rule safe."
@@ -143,11 +143,10 @@ def argmax(a, axis=None, out=None, *, keepdims=False):
143143

144144
usm_a = dpnp.get_usm_ndarray(a)
145145
return dpnp_wrap_reduction_call(
146-
a,
146+
usm_a,
147147
out,
148148
dpt.argmax,
149-
_get_search_res_dt,
150-
usm_a,
149+
_get_search_res_dt(a, out),
151150
axis=axis,
152151
keepdims=keepdims,
153152
)
@@ -234,11 +233,10 @@ def argmin(a, axis=None, out=None, *, keepdims=False):
234233

235234
usm_a = dpnp.get_usm_ndarray(a)
236235
return dpnp_wrap_reduction_call(
237-
a,
236+
usm_a,
238237
out,
239238
dpt.argmin,
240-
_get_search_res_dt,
241-
usm_a,
239+
_get_search_res_dt(a, out),
242240
axis=axis,
243241
keepdims=keepdims,
244242
)

dpnp/dpnp_iface_statistics.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,6 @@ def _count_reduce_items(arr, axis, where=True):
115115
return items
116116

117117

118-
def _get_comparison_res_dt(a, _dtype, _out):
119-
"""Get a data type used by dpctl for result array in comparison function."""
120-
121-
return a.dtype
122-
123-
124118
def amax(a, axis=None, out=None, keepdims=False, initial=None, where=True):
125119
"""
126120
Return the maximum of an array or maximum along an axis.
@@ -760,11 +754,10 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
760754
usm_a = dpnp.get_usm_ndarray(a)
761755

762756
return dpnp_wrap_reduction_call(
763-
a,
757+
usm_a,
764758
out,
765759
dpt.max,
766-
_get_comparison_res_dt,
767-
usm_a,
760+
a.dtype,
768761
axis=axis,
769762
keepdims=keepdims,
770763
)
@@ -1026,11 +1019,10 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
10261019
usm_a = dpnp.get_usm_ndarray(a)
10271020

10281021
return dpnp_wrap_reduction_call(
1029-
a,
1022+
usm_a,
10301023
out,
10311024
dpt.min,
1032-
_get_comparison_res_dt,
1033-
usm_a,
1025+
a.dtype,
10341026
axis=axis,
10351027
keepdims=keepdims,
10361028
)

dpnp/dpnp_iface_trigonometric.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
]
9999

100100

101-
def _get_accumulation_res_dt(a, dtype, _out):
101+
def _get_accumulation_res_dt(a, dtype):
102102
"""Get a dtype used by dpctl for result array in accumulation function."""
103103

104104
if dtype is None:
@@ -893,11 +893,10 @@ def cumlogsumexp(
893893
usm_x = dpnp.get_usm_ndarray(x)
894894

895895
return dpnp_wrap_reduction_call(
896-
x,
896+
usm_x,
897897
out,
898898
dpt.cumulative_logsumexp,
899-
_get_accumulation_res_dt,
900-
usm_x,
899+
_get_accumulation_res_dt(x, dtype),
901900
axis=axis,
902901
dtype=dtype,
903902
include_initial=include_initial,
@@ -1705,11 +1704,10 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
17051704

17061705
usm_x = dpnp.get_usm_ndarray(x)
17071706
return dpnp_wrap_reduction_call(
1708-
x,
1707+
usm_x,
17091708
out,
17101709
dpt.logsumexp,
1711-
_get_accumulation_res_dt,
1712-
usm_x,
1710+
_get_accumulation_res_dt(x, dtype),
17131711
axis=axis,
17141712
dtype=dtype,
17151713
keepdims=keepdims,
@@ -1952,11 +1950,10 @@ def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
19521950

19531951
usm_x = dpnp.get_usm_ndarray(x)
19541952
return dpnp_wrap_reduction_call(
1955-
x,
1953+
usm_x,
19561954
out,
19571955
dpt.reduce_hypot,
1958-
_get_accumulation_res_dt,
1959-
usm_x,
1956+
_get_accumulation_res_dt(x, dtype),
19601957
axis=axis,
19611958
dtype=dtype,
19621959
keepdims=keepdims,

dpnp/dpnp_utils/dpnp_utils_reduction.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
__all__ = ["dpnp_wrap_reduction_call"]
3030

3131

32-
def dpnp_wrap_reduction_call(
33-
a, out, _reduction_fn, _get_res_dt_fn, *args, **kwargs
34-
):
32+
def dpnp_wrap_reduction_call(usm_a, out, _reduction_fn, res_dt, **kwargs):
3533
"""Wrap a reduction call from dpctl.tensor interface."""
3634

3735
input_out = out
@@ -40,16 +38,12 @@ def dpnp_wrap_reduction_call(
4038
else:
4139
dpnp.check_supported_arrays_type(out)
4240

43-
# fetch dtype from the passed kwargs to the reduction call
44-
dtype = kwargs.get("dtype", None)
45-
4641
# dpctl requires strict data type matching of out array with the result
47-
res_dt = _get_res_dt_fn(a, dtype, out)
4842
if out.dtype != res_dt:
4943
out = dpnp.astype(out, dtype=res_dt, copy=False)
5044

5145
usm_out = dpnp.get_usm_ndarray(out)
5246

5347
kwargs["out"] = usm_out
54-
res_usm = _reduction_fn(*args, **kwargs)
48+
res_usm = _reduction_fn(usm_a, **kwargs)
5549
return dpnp.get_result_array(res_usm, input_out, casting="unsafe")

0 commit comments

Comments
 (0)