Skip to content

Commit d0e722b

Browse files
add normalize_axis_tuple
1 parent 5e7b424 commit d0e722b

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import numpy
4444
import dpctl.tensor as dpt
45+
from numpy.core.numeric import normalize_axis_tuple
4546
from dpnp.dpnp_algo import *
4647
from dpnp.dpnp_utils import *
4748
from dpnp.dpnp_array import dpnp_array
@@ -445,27 +446,28 @@ def mean(x, /, *, axis=None, dtype=None, keepdims=False, out=None, where=True):
445446
elif where is not True:
446447
pass
447448
else:
448-
449-
if dtype is None and not numpy.issubdtype(x.dtype, numpy.integer):
449+
if dtype is None and not (dpnp.issubdtype(x.dtype, dpnp.integer)
450+
or dpnp.issubdtype(x.dtype, dpnp.bool_)):
450451
dtype = x.dtype
451452

452453
if axis is None:
453454
if x.size == 0:
454-
return dpnp.array([dpnp.nan], dtype=x.dtype)
455+
return dpnp.array([dpnp.nan], dtype=dtype)
455456
else:
456457
result = dpnp.sum(x, dtype=dtype) / x.size
457458
return result.astype(dtype) if dtype else result
458459

459-
if isinstance(axis, int):
460-
axis_ = (axis,)
461-
else:
462-
axis_ = axis
460+
if not isinstance(axis,(tuple,list)):
461+
axis = (axis,)
463462

463+
axis = normalize_axis_tuple(axis, x.ndim, "axis")
464464
res_sum = dpnp.sum(x, axis=axis, dtype=dtype)
465465

466-
for axis_value in axis_:
467-
res_sum /= x.shape[axis_value]
466+
del_ = 1.0
467+
for axis_value in axis:
468+
del_ *= x.shape[axis_value]
468469

470+
res_sum /= del_
469471
return res_sum.astype(dtype) if dtype else res_sum
470472

471473
return call_origin(numpy.mean, x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where)

0 commit comments

Comments
 (0)