Skip to content

Commit 21be8f4

Browse files
authored
Merge branch 'master' into nd-support-to-trim_zero
2 parents ee19ab4 + cabc0d7 commit 21be8f4

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

dpnp/dpnp_utils/dpnp_utils_statistics.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def _calc_median(a, axis, out=None):
5757
return res
5858

5959

60-
def _calc_nanmedian(a, axis, out=None):
60+
def _calc_nanmedian(a, out=None):
6161
"""Compute the median of an array along a specified axis, ignoring NaNs."""
6262
mask = dpnp.isnan(a)
63-
valid_counts = dpnp.sum(~mask, axis=axis)
63+
valid_counts = dpnp.sum(~mask, axis=-1)
6464
if out is None:
6565
res = dpnp.empty_like(valid_counts, dtype=a.dtype)
6666
else:
@@ -76,27 +76,19 @@ def _calc_nanmedian(a, axis, out=None):
7676
)
7777
res = out
7878

79-
# Iterate over all indices of the output shape
80-
for idx in dpnp.ndindex(res.shape):
81-
current_valid_counts = valid_counts[idx]
79+
left = (valid_counts - 1) // 2
80+
right = valid_counts // 2
8281

83-
if current_valid_counts > 0:
84-
# Extract the corresponding slice from the last axis of `a`
85-
data = a[idx][:current_valid_counts]
86-
left = (current_valid_counts - 1) // 2
87-
right = current_valid_counts // 2
82+
left_data = dpnp.take_along_axis(a, left[..., None], axis=-1)
83+
right_data = dpnp.take_along_axis(a, right[..., None], axis=-1)
84+
res = dpnp.where(
85+
valid_counts[..., None] > 0, (left_data + right_data) / 2.0, dpnp.nan
86+
)
8887

89-
if left == right:
90-
res[idx] = data[left]
91-
else:
92-
res[idx] = (data[left] + data[right]) / 2.0
93-
else:
94-
warnings.warn(
95-
"All-NaN slice encountered", RuntimeWarning, stacklevel=6
96-
)
97-
res[idx] = dpnp.nan
88+
if mask.all(axis=-1).any():
89+
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=6)
9890

99-
return res
91+
return dpnp.squeeze(res)
10092

10193

10294
def _flatten_array_along_axes(a, axes_to_flatten, overwrite_input):
@@ -232,7 +224,8 @@ def dpnp_median(
232224

233225
if ignore_nan:
234226
# sorting puts NaNs at the end
235-
res = _calc_nanmedian(a_sorted, axis=axis, out=out)
227+
assert axis == -1
228+
res = _calc_nanmedian(a_sorted, out=out)
236229
else:
237230
# We can't pass keepdims and use it in dpnp.mean and dpnp.any
238231
# because of the reshape hack that might have been used in

0 commit comments

Comments
 (0)