Skip to content

Commit 14274d8

Browse files
authored
Merge pull request #2339 from IntelPython/reuse-nan-to-num-nan-fns
Reuse `dpnp.nan_to_num` in `dpnp.nansum` and `dpnp.nanprod`
2 parents 96e723d + 1995cd5 commit 14274d8

File tree

1 file changed

+58
-8
lines changed

1 file changed

+58
-8
lines changed

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,36 @@
6060
]
6161

6262

63+
def _replace_nan_no_mask(a, val):
64+
"""
65+
Replace NaNs in array `a` with `val`.
66+
67+
If `a` is of inexact type, make a copy of `a`, replace NaNs with
68+
the `val` value, and return the copy. If `a` is not of inexact type,
69+
do nothing and return `a`.
70+
71+
Parameters
72+
----------
73+
a : {dpnp.ndarray, usm_ndarray}
74+
Input array.
75+
val : float
76+
NaN values are set to `val` before doing the operation.
77+
78+
Returns
79+
-------
80+
out : dpnp.ndarray
81+
If `a` is of inexact type, return a copy of `a` with the NaNs
82+
replaced by the fill value, otherwise return `a`.
83+
84+
"""
85+
86+
dpnp.check_supported_arrays_type(a)
87+
if dpnp.issubdtype(a.dtype, dpnp.inexact):
88+
return dpnp.nan_to_num(a, nan=val, posinf=dpnp.inf, neginf=-dpnp.inf)
89+
90+
return a
91+
92+
6393
def _replace_nan(a, val):
6494
"""
6595
Replace NaNs in array `a` with `val`.
@@ -107,6 +137,18 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False):
107137
108138
For full documentation refer to :obj:`numpy.nanargmax`.
109139
140+
Warning
141+
-------
142+
This function synchronizes in order to test for all-NaN slices in the array.
143+
This may harm performance in some applications. To avoid synchronization,
144+
the user is recommended to filter NaNs themselves and use `dpnp.argmax`
145+
on the filtered array.
146+
147+
Warning
148+
-------
149+
The results cannot be trusted if a slice contains only NaNs
150+
and -Infs.
151+
110152
Parameters
111153
----------
112154
a : {dpnp.ndarray, usm_ndarray}
@@ -136,8 +178,6 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False):
136178
values ignoring NaNs. The returned array must have the default array
137179
index data type.
138180
For all-NaN slices ``ValueError`` is raised.
139-
Warning: the results cannot be trusted if a slice contains only NaNs
140-
and -Infs.
141181
142182
Limitations
143183
-----------
@@ -181,6 +221,18 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):
181221
182222
For full documentation refer to :obj:`numpy.nanargmin`.
183223
224+
Warning
225+
-------
226+
This function synchronizes in order to test for all-NaN slices in the array.
227+
This may harm performance in some applications. To avoid synchronization,
228+
the user is recommended to filter NaNs themselves and use `dpnp.argmax`
229+
on the filtered array.
230+
231+
Warning
232+
-------
233+
The results cannot be trusted if a slice contains only NaNs
234+
and -Infs.
235+
184236
Parameters
185237
----------
186238
a : {dpnp.ndarray, usm_ndarray}
@@ -210,8 +262,6 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):
210262
values ignoring NaNs. The returned array must have the default array
211263
index data type.
212264
For all-NaN slices ``ValueError`` is raised.
213-
Warning: the results cannot be trusted if a slice contains only NaNs
214-
and Infs.
215265
216266
Limitations
217267
-----------
@@ -315,7 +365,7 @@ def nancumprod(a, axis=None, dtype=None, out=None):
315365
316366
"""
317367

318-
a, _ = _replace_nan(a, 1)
368+
a = _replace_nan_no_mask(a, 1.0)
319369
return dpnp.cumprod(a, axis=axis, dtype=dtype, out=out)
320370

321371

@@ -385,7 +435,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
385435
386436
"""
387437

388-
a, _ = _replace_nan(a, 0)
438+
a = _replace_nan_no_mask(a, 0.0)
389439
return dpnp.cumsum(a, axis=axis, dtype=dtype, out=out)
390440

391441

@@ -884,7 +934,7 @@ def nanprod(
884934
885935
"""
886936

887-
a, _ = _replace_nan(a, 1)
937+
a = _replace_nan_no_mask(a, 1.0)
888938
return dpnp.prod(
889939
a,
890940
axis=axis,
@@ -988,7 +1038,7 @@ def nansum(
9881038
9891039
"""
9901040

991-
a, _ = _replace_nan(a, 0)
1041+
a = _replace_nan_no_mask(a, 0.0)
9921042
return dpnp.sum(
9931043
a,
9941044
axis=axis,

0 commit comments

Comments
 (0)