Skip to content

Commit b6e8b05

Browse files
committed
Extend dpnp.get_result_array() to accept dpt.usm_ndarray
1 parent 0f633af commit b6e8b05

File tree

4 files changed

+17
-30
lines changed

4 files changed

+17
-30
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151

5252
import dpnp
5353
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
54-
from dpnp.dpnp_array import dpnp_array
5554

5655
__all__ = [
5756
"all",
@@ -167,13 +166,11 @@ def all(a, /, axis=None, out=None, keepdims=False, *, where=True):
167166

168167
dpnp.check_limitations(where=where)
169168

170-
dpt_array = dpnp.get_usm_ndarray(a)
171-
result = dpnp_array._create_from_usm_ndarray(
172-
dpt.all(dpt_array, axis=axis, keepdims=keepdims)
173-
)
169+
usm_a = dpnp.get_usm_ndarray(a)
170+
usm_res = dpt.all(usm_a, axis=axis, keepdims=keepdims)
171+
174172
# TODO: temporary solution until dpt.all supports out parameter
175-
result = dpnp.get_result_array(result, out)
176-
return result
173+
return dpnp.get_result_array(usm_res, out)
177174

178175

179176
def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
@@ -333,13 +330,11 @@ def any(a, /, axis=None, out=None, keepdims=False, *, where=True):
333330

334331
dpnp.check_limitations(where=where)
335332

336-
dpt_array = dpnp.get_usm_ndarray(a)
337-
result = dpnp_array._create_from_usm_ndarray(
338-
dpt.any(dpt_array, axis=axis, keepdims=keepdims)
339-
)
333+
usm_a = dpnp.get_usm_ndarray(a)
334+
usm_res = dpt.any(usm_a, axis=axis, keepdims=keepdims)
335+
340336
# TODO: temporary solution until dpt.any supports out parameter
341-
result = dpnp.get_result_array(result, out)
342-
return result
337+
return dpnp.get_result_array(usm_res, out)
343338

344339

345340
_EQUAL_DOCSTRING = """

dpnp/dpnp_iface_searching.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,5 @@ def where(condition, x=None, y=None, /, *, order="K", out=None):
391391
usm_condition = dpnp.get_usm_ndarray(condition)
392392

393393
usm_out = None if out is None else dpnp.get_usm_ndarray(out)
394-
result = dpnp_array._create_from_usm_ndarray(
395-
dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out)
396-
)
397-
return dpnp.get_result_array(result, out)
394+
usm_res = dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out)
395+
return dpnp.get_result_array(usm_res, out)

dpnp/dpnp_iface_statistics.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -904,11 +904,9 @@ def std(
904904
)
905905
dpnp.sqrt(result, out=result)
906906
else:
907-
dpt_array = dpnp.get_usm_ndarray(a)
908-
result = dpnp_array._create_from_usm_ndarray(
909-
dpt.std(dpt_array, axis=axis, correction=ddof, keepdims=keepdims)
910-
)
911-
result = dpnp.get_result_array(result, out)
907+
usm_a = dpnp.get_usm_ndarray(a)
908+
usm_res = dpt.std(usm_a, axis=axis, correction=ddof, keepdims=keepdims)
909+
result = dpnp.get_result_array(usm_res, out)
912910

913911
if dtype is not None and out is None:
914912
result = result.astype(dtype, casting="same_kind")
@@ -1028,11 +1026,9 @@ def var(
10281026

10291027
dpnp.divide(result, cnt, out=result)
10301028
else:
1031-
dpt_array = dpnp.get_usm_ndarray(a)
1032-
result = dpnp_array._create_from_usm_ndarray(
1033-
dpt.var(dpt_array, axis=axis, correction=ddof, keepdims=keepdims)
1034-
)
1035-
result = dpnp.get_result_array(result, out)
1029+
usm_a = dpnp.get_usm_ndarray(a)
1030+
usm_res = dpt.var(usm_a, axis=axis, correction=ddof, keepdims=keepdims)
1031+
result = dpnp.get_result_array(usm_res, out)
10361032

10371033
if out is None and dtype is not None:
10381034
result = result.astype(dtype, casting="same_kind")

dpnp/dpnp_utils/dpnp_utils_reduction.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626

2727
import dpnp
28-
from dpnp.dpnp_array import dpnp_array
2928

3029
__all__ = ["dpnp_wrap_reduction_call"]
3130

@@ -53,5 +52,4 @@ def dpnp_wrap_reduction_call(
5352

5453
kwargs["out"] = usm_out
5554
res_usm = _reduction_fn(*args, **kwargs)
56-
res = dpnp_array._create_from_usm_ndarray(res_usm)
57-
return dpnp.get_result_array(res, input_out, casting="unsafe")
55+
return dpnp.get_result_array(res_usm, input_out, casting="unsafe")

0 commit comments

Comments
 (0)