|
42 | 42 |
|
43 | 43 | import numpy
|
44 | 44 | import dpctl.tensor as dpt
|
| 45 | +from numpy.core.numeric import normalize_axis_tuple |
45 | 46 | from dpnp.dpnp_algo import *
|
46 | 47 | from dpnp.dpnp_utils import *
|
47 | 48 | from dpnp.dpnp_array import dpnp_array
|
@@ -445,27 +446,28 @@ def mean(x, /, *, axis=None, dtype=None, keepdims=False, out=None, where=True):
|
445 | 446 | elif where is not True:
|
446 | 447 | pass
|
447 | 448 | else:
|
448 |
| - |
449 |
| - if dtype is None and not numpy.issubdtype(x.dtype, numpy.integer): |
| 449 | + if dtype is None and not (dpnp.issubdtype(x.dtype, dpnp.integer) |
| 450 | + or dpnp.issubdtype(x.dtype, dpnp.bool_)): |
450 | 451 | dtype = x.dtype
|
451 | 452 |
|
452 | 453 | if axis is None:
|
453 | 454 | if x.size == 0:
|
454 |
| - return dpnp.array([dpnp.nan], dtype=x.dtype) |
| 455 | + return dpnp.array([dpnp.nan], dtype=dtype) |
455 | 456 | else:
|
456 | 457 | result = dpnp.sum(x, dtype=dtype) / x.size
|
457 | 458 | return result.astype(dtype) if dtype else result
|
458 | 459 |
|
459 |
| - if isinstance(axis, int): |
460 |
| - axis_ = (axis,) |
461 |
| - else: |
462 |
| - axis_ = axis |
| 460 | + if not isinstance(axis,(tuple,list)): |
| 461 | + axis = (axis,) |
463 | 462 |
|
| 463 | + axis = normalize_axis_tuple(axis, x.ndim, "axis") |
464 | 464 | res_sum = dpnp.sum(x, axis=axis, dtype=dtype)
|
465 | 465 |
|
466 |
| - for axis_value in axis_: |
467 |
| - res_sum /= x.shape[axis_value] |
| 466 | + del_ = 1.0 |
| 467 | + for axis_value in axis: |
| 468 | + del_ *= x.shape[axis_value] |
468 | 469 |
|
| 470 | + res_sum /= del_ |
469 | 471 | return res_sum.astype(dtype) if dtype else res_sum
|
470 | 472 |
|
471 | 473 | return call_origin(numpy.mean, x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where)
|
|
0 commit comments