Skip to content

Commit 8124094

Browse files
committed
implement dpnp.nanmedian
1 parent cd23361 commit 8124094

File tree

8 files changed

+418
-8
lines changed

8 files changed

+418
-8
lines changed

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import warnings
4141

4242
import dpnp
43+
from dpnp.dpnp_utils.dpnp_utils_statistics import dpnp_nanmedian
4344

4445
__all__ = [
4546
"nanargmax",
@@ -48,6 +49,7 @@
4849
"nancumsum",
4950
"nanmax",
5051
"nanmean",
52+
"nanmedian",
5153
"nanmin",
5254
"nanprod",
5355
"nanstd",
@@ -568,6 +570,112 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
568570
return avg
569571

570572

573+
def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
574+
"""
575+
Compute the median along the specified axis, while ignoring NaNs.
576+
577+
For full documentation refer to :obj:`numpy.nanmedian`.
578+
579+
Parameters
580+
----------
581+
a : {dpnp.ndarray, usm_ndarray}
582+
Input array.
583+
axis : {None, int, tuple or list of ints}, optional
584+
Axis or axes along which the medians are computed. The default,
585+
``axis=None``, will compute the median along a flattened version of
586+
the array. If a sequence of axes, the array is first flattened along
587+
the given axes, then the median is computed along the resulting
588+
flattened axis.
589+
Default: ``None``.
590+
out : {None, dpnp.ndarray, usm_ndarray}, optional
591+
Alternative output array in which to place the result. It must have
592+
the same shape as the expected output but the type (of the calculated
593+
values) will be cast if necessary.
594+
Default: ``None``.
595+
overwrite_input : bool, optional
596+
If ``True``, then allow use of memory of input array `a` for
597+
calculations. The input array will be modified by the call to
598+
:obj:`dpnp.nanmedian`. This will save memory when you do not need to
599+
preserve the contents of the input array. Treat the input as undefined,
600+
but it will probably be fully or partially sorted.
601+
Default: ``False``.
602+
keepdims : bool, optional
603+
If ``True``, the reduced axes (dimensions) are included in the result
604+
as singleton dimensions, so that the returned array remains
605+
compatible with the input array according to Array Broadcasting
606+
rules. Otherwise, if ``False``, the reduced axes are not included in
607+
the returned array.
608+
Default: ``False``.
609+
610+
Returns
611+
-------
612+
out : dpnp.ndarray
613+
A new array holding the result. If `a` has a floating-point data type,
614+
the returned array will have the same data type as `a`. If `a` has a
615+
boolean or integral data type, the returned array will have the
616+
default floating point data type for the device where input array `a`
617+
is allocated.
618+
619+
See Also
620+
--------
621+
:obj:`dpnp.mean` : Compute the arithmetic mean along the specified axis.
622+
:obj:`dpnp.median` : Compute the median along the specified axis.
623+
:obj:`dpnp.percentile` : Compute the q-th percentile of the data
624+
along the specified axis.
625+
626+
Notes
627+
-----
628+
Given a vector ``V`` of length ``N``, the median of ``V`` is the
629+
middle value of a sorted copy of ``V``, ``V_sorted`` - i.e.,
630+
``V_sorted[(N-1)/2]``, when ``N`` is odd, and the average of the
631+
two middle values of ``V_sorted`` when ``N`` is even.
632+
633+
Examples
634+
--------
635+
>>> import dpnp as np
636+
>>> a = np.array([[10.0, 7, 4], [3, 2, 1]])
637+
>>> a[0, 1] = np.nan
638+
>>> a
639+
array([[10., nan, 4.],
640+
[ 3., 2., 1.]])
641+
>>> np.median(a)
642+
array(nan)
643+
>>> np.nanmedian(a)
644+
array(3.)
645+
646+
>>> np.nanmedian(a, axis=0)
647+
array([6.5, 2., 2.5])
648+
>>> np.nanmedian(a, axis=1)
649+
array([7., 2.])
650+
651+
>>> b = a.copy()
652+
>>> np.nanmedian(b, axis=1, overwrite_input=True)
653+
array([7., 2.])
654+
>>> assert not np.all(a==b)
655+
>>> b = a.copy()
656+
>>> np.nanmedian(b, axis=None, overwrite_input=True)
657+
array(3.)
658+
>>> assert not np.all(a==b)
659+
660+
"""
661+
662+
dpnp.check_supported_arrays_type(a)
663+
if dpnp.issubdtype(a.dtype, dpnp.inexact):
664+
# apply_along_axis in _nanmedian doesn't handle empty arrays well,
665+
# so deal them upfront
666+
if a.size == 0:
667+
return dpnp.nanmean(a, axis, out=out, keepdims=keepdims)
668+
return dpnp_nanmedian(
669+
a,
670+
keepdims=keepdims,
671+
axis=axis,
672+
out=out,
673+
overwrite_input=overwrite_input,
674+
)
675+
676+
return dpnp.median(a, axis, out, overwrite_input, keepdims)
677+
678+
571679
def nanmin(a, axis=None, out=None, keepdims=False, initial=None, where=True):
572680
"""
573681
Return the minimum of an array or minimum along an axis, ignoring any NaNs.

dpnp/dpnp_iface_statistics.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
765765
preserve the contents of the input array. Treat the input as undefined,
766766
but it will probably be fully or partially sorted.
767767
Default: ``False``.
768-
keepdims : {None, bool}, optional
768+
keepdims : bool, optional
769769
If ``True``, the reduced axes (dimensions) are included in the result
770770
as singleton dimensions, so that the returned array remains
771771
compatible with the input array according to Array Broadcasting
@@ -775,7 +775,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
775775
776776
Returns
777777
-------
778-
dpnp.median : dpnp.ndarray
778+
out : dpnp.ndarray
779779
A new array holding the result. If `a` has a floating-point data type,
780780
the returned array will have the same data type as `a`. If `a` has a
781781
boolean or integral data type, the returned array will have the
@@ -808,20 +808,20 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
808808
>>> np.median(a, axis=0)
809809
array([6.5, 4.5, 2.5])
810810
>>> np.median(a, axis=1)
811-
array([7., 2.])
811+
array([7., 2.])
812812
>>> np.median(a, axis=(0, 1))
813813
array(3.5)
814814
815815
>>> m = np.median(a, axis=0)
816816
>>> out = np.zeros_like(m)
817817
>>> np.median(a, axis=0, out=m)
818-
array([6.5, 4.5, 2.5])
818+
array([6.5, 4.5, 2.5])
819819
>>> m
820-
array([6.5, 4.5, 2.5])
820+
array([6.5, 4.5, 2.5])
821821
822822
>>> b = a.copy()
823823
>>> np.median(b, axis=1, overwrite_input=True)
824-
array([7., 2.])
824+
array([7., 2.])
825825
>>> assert not np.all(a==b)
826826
>>> b = a.copy()
827827
>>> np.median(b, axis=None, overwrite_input=True)

dpnp/dpnp_utils/dpnp_utils_statistics.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,94 @@
2424
# *****************************************************************************
2525

2626

27+
import warnings
28+
29+
from dpctl.tensor._numpy_helper import normalize_axis_tuple
30+
2731
import dpnp
2832
from dpnp.dpnp_utils import get_usm_allocations, map_dtype_to_device
2933

30-
__all__ = ["dpnp_cov"]
34+
__all__ = ["dpnp_cov", "dpnp_nanmedian"]
35+
36+
37+
def _calc_nanmedian(a, axis=None, out=None, overwrite_input=False):
38+
"""
39+
Private function that doesn't support extended axis or keepdims.
40+
These methods are extended to this function using _ureduce
41+
See nanmedian for parameter usage
42+
43+
"""
44+
if axis is None or a.ndim == 1:
45+
part = dpnp.ravel(a)
46+
if out is None:
47+
return _nanmedian1d(part, overwrite_input)
48+
else:
49+
out[...] = _nanmedian1d(part, overwrite_input)
50+
return out
51+
else:
52+
result = dpnp.apply_along_axis(_nanmedian1d, axis, a, overwrite_input)
53+
if out is not None:
54+
out[...] = result
55+
return result
56+
57+
58+
def _nanmedian1d(arr1d, overwrite_input=False):
59+
"""
60+
Private function for rank 1 arrays. Compute the median ignoring NaNs.
61+
See nanmedian for parameter usage
62+
"""
63+
arr1d_parsed, overwrite_input = _remove_nan_1d(
64+
arr1d,
65+
overwrite_input=overwrite_input,
66+
)
67+
68+
if arr1d_parsed.size == 0:
69+
# Ensure that a nan-esque scalar of the appropriate type (and unit)
70+
# is returned for `complexfloating`
71+
return arr1d[-1]
72+
73+
return dpnp.median(arr1d_parsed, overwrite_input=overwrite_input)
74+
75+
76+
def _remove_nan_1d(arr1d, overwrite_input=False):
77+
"""
78+
Equivalent to arr1d[~arr1d.isnan()], but in a different order
79+
80+
Presumably faster as it incurs fewer copies
81+
82+
Parameters
83+
----------
84+
arr1d : ndarray
85+
Array to remove nans from
86+
overwrite_input : bool
87+
True if `arr1d` can be modified in place
88+
89+
Returns
90+
-------
91+
res : ndarray
92+
Array with nan elements removed
93+
overwrite_input : bool
94+
True if `res` can be modified in place, given the constraint on the
95+
input
96+
"""
97+
98+
mask = dpnp.isnan(arr1d)
99+
100+
s = dpnp.nonzero(mask)[0]
101+
if s.size == arr1d.size:
102+
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=6)
103+
return arr1d[:0], True
104+
elif s.size == 0:
105+
return arr1d, overwrite_input
106+
else:
107+
if not overwrite_input:
108+
arr1d = arr1d.copy()
109+
# select non-nans at end of array
110+
enonan = arr1d[-s.size :][~mask[-s.size :]]
111+
# fill nans in beginning of array with non-nans of end
112+
arr1d[s[: enonan.size]] = enonan
113+
114+
return arr1d[: -s.size], True
31115

32116

33117
def dpnp_cov(m, y=None, rowvar=True, dtype=None):
@@ -90,3 +174,53 @@ def _get_2dmin_array(x, dtype):
90174
c *= 1 / fact if fact != 0 else dpnp.nan
91175

92176
return dpnp.squeeze(c)
177+
178+
179+
def dpnp_nanmedian(
180+
a, keepdims=False, axis=None, out=None, overwrite_input=False
181+
):
182+
"""Internal Function."""
183+
184+
nd = a.ndim
185+
if axis is not None:
186+
_axis = normalize_axis_tuple(axis, nd)
187+
188+
if keepdims:
189+
if out is not None:
190+
index_out = tuple(
191+
0 if i in _axis else slice(None) for i in range(nd)
192+
)
193+
out = out[(Ellipsis,) + index_out]
194+
195+
if len(_axis) == 1:
196+
axis = _axis[0]
197+
else:
198+
keep = set(range(nd)) - set(_axis)
199+
nkeep = len(keep)
200+
# swap axis that should not be reduced to front
201+
for i, s in enumerate(sorted(keep)):
202+
a = dpnp.swapaxes(a, i, s)
203+
# merge reduced axis
204+
a = a.reshape(a.shape[:nkeep] + (-1,))
205+
axis = -1
206+
else:
207+
if keepdims:
208+
if out is not None:
209+
index_out = (0,) * nd
210+
out = out[(Ellipsis,) + index_out]
211+
212+
r = _calc_nanmedian(a, axis=axis, out=out, overwrite_input=overwrite_input)
213+
214+
if out is not None:
215+
return out
216+
217+
if keepdims:
218+
if axis is None:
219+
index_r = (dpnp.newaxis,) * nd
220+
else:
221+
index_r = tuple(
222+
dpnp.newaxis if i in _axis else slice(None) for i in range(nd)
223+
)
224+
r = r[(Ellipsis,) + index_r]
225+
226+
return r

0 commit comments

Comments
 (0)