@@ -57,10 +57,10 @@ def _calc_median(a, axis, out=None):
57
57
return res
58
58
59
59
60
- def _calc_nanmedian (a , axis , out = None ):
60
+ def _calc_nanmedian (a , out = None ):
61
61
"""Compute the median of an array along a specified axis, ignoring NaNs."""
62
62
mask = dpnp .isnan (a )
63
- valid_counts = dpnp .sum (~ mask , axis = axis )
63
+ valid_counts = dpnp .sum (~ mask , axis = - 1 )
64
64
if out is None :
65
65
res = dpnp .empty_like (valid_counts , dtype = a .dtype )
66
66
else :
@@ -76,27 +76,19 @@ def _calc_nanmedian(a, axis, out=None):
76
76
)
77
77
res = out
78
78
79
- # Iterate over all indices of the output shape
80
- for idx in dpnp .ndindex (res .shape ):
81
- current_valid_counts = valid_counts [idx ]
79
+ left = (valid_counts - 1 ) // 2
80
+ right = valid_counts // 2
82
81
83
- if current_valid_counts > 0 :
84
- # Extract the corresponding slice from the last axis of `a`
85
- data = a [ idx ][: current_valid_counts ]
86
- left = ( current_valid_counts - 1 ) // 2
87
- right = current_valid_counts // 2
82
+ left_data = dpnp . take_along_axis ( a , left [..., None ], axis = - 1 )
83
+ right_data = dpnp . take_along_axis ( a , right [..., None ], axis = - 1 )
84
+ res = dpnp . where (
85
+ valid_counts [..., None ] > 0 , ( left_data + right_data ) / 2.0 , dpnp . nan
86
+ )
88
87
89
- if left == right :
90
- res [idx ] = data [left ]
91
- else :
92
- res [idx ] = (data [left ] + data [right ]) / 2.0
93
- else :
94
- warnings .warn (
95
- "All-NaN slice encountered" , RuntimeWarning , stacklevel = 6
96
- )
97
- res [idx ] = dpnp .nan
88
+ if mask .all (axis = - 1 ).any ():
89
+ warnings .warn ("All-NaN slice encountered" , RuntimeWarning , stacklevel = 6 )
98
90
99
- return res
91
+ return dpnp . squeeze ( res )
100
92
101
93
102
94
def _flatten_array_along_axes (a , axes_to_flatten , overwrite_input ):
@@ -232,7 +224,8 @@ def dpnp_median(
232
224
233
225
if ignore_nan :
234
226
# sorting puts NaNs at the end
235
- res = _calc_nanmedian (a_sorted , axis = axis , out = out )
227
+ assert axis == - 1
228
+ res = _calc_nanmedian (a_sorted , out = out )
236
229
else :
237
230
# We can't pass keepdims and use it in dpnp.mean and dpnp.any
238
231
# because of the reshape hack that might have been used in
0 commit comments