Skip to content

Commit 0012e7a

Browse files
committed
Updated dpnp.mean() per review comment
1 parent b6e8b05 commit 0012e7a

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -614,13 +614,12 @@ def mean(a, /, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
614614

615615
dpnp.check_limitations(where=where)
616616

617-
dpt_array = dpnp.get_usm_ndarray(a)
618-
result = dpnp_array._create_from_usm_ndarray(
619-
dpt.mean(dpt_array, axis=axis, keepdims=keepdims)
620-
)
621-
result = result.astype(dtype) if dtype is not None else result
617+
usm_a = dpnp.get_usm_ndarray(a)
618+
usm_res = dpt.mean(usm_a, axis=axis, keepdims=keepdims)
619+
if dtype is not None:
620+
usm_res = dpt.astype(usm_res, dtype)
622621

623-
return dpnp.get_result_array(result, out, casting="same_kind")
622+
return dpnp.get_result_array(usm_res, out, casting="same_kind")
624623

625624

626625
def median(x1, axis=None, out=None, overwrite_input=False, keepdims=False):

0 commit comments

Comments
 (0)