Skip to content

Commit 40af29a

Browse files
committed
Reuse nan_to_num in nancumprod and nancumsum
1 parent 4c0908b commit 40af29a

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,31 @@
6060
]
6161

6262

63+
def _replace_nan_no_mask(a, val):
64+
"""
65+
Replace NaNs in array `a` with `val`.
66+
67+
If `a` is of inexact type, make a copy of `a`, replace NaNs with
68+
the `val` value, and return the copy together. If `a` is not of
69+
inexact type, do nothing and return `a`.
70+
71+
Parameters
72+
----------
73+
a : {dpnp.ndarray, usm_ndarray}
74+
Input array.
75+
val : float
76+
NaN values are set to `val` before doing the operation.
77+
78+
Returns
79+
-------
80+
out : {dpnp.ndarray}
81+
If `a` is of inexact type, return a copy of `a` with the NaNs
82+
replaced by the fill value, otherwise return `a`.
83+
"""
84+
85+
return dpnp.nan_to_num(a, nan=val, posinf=dpnp.inf, neginf=-dpnp.inf)
86+
87+
6388
def _replace_nan(a, val):
6489
"""
6590
Replace NaNs in array `a` with `val`.
@@ -315,7 +340,7 @@ def nancumprod(a, axis=None, dtype=None, out=None):
315340
316341
"""
317342

318-
a, _ = _replace_nan(a, 1)
343+
a = _replace_nan_no_mask(a, 1.0)
319344
return dpnp.cumprod(a, axis=axis, dtype=dtype, out=out)
320345

321346

@@ -385,7 +410,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
385410
386411
"""
387412

388-
a, _ = _replace_nan(a, 0)
413+
a = _replace_nan_no_mask(a, 0.0)
389414
return dpnp.cumsum(a, axis=axis, dtype=dtype, out=out)
390415

391416

@@ -884,7 +909,7 @@ def nanprod(
884909
885910
"""
886911

887-
a = dpnp.nan_to_num(a, nan=1.0, posinf=dpnp.inf, neginf=-dpnp.inf)
912+
a = _replace_nan_no_mask(a, 1.0)
888913
return dpnp.prod(
889914
a,
890915
axis=axis,
@@ -988,7 +1013,7 @@ def nansum(
9881013
9891014
"""
9901015

991-
a = dpnp.nan_to_num(a, nan=0.0, posinf=dpnp.inf, neginf=-dpnp.inf)
1016+
a = _replace_nan_no_mask(a, 0.0)
9921017
return dpnp.sum(
9931018
a,
9941019
axis=axis,

0 commit comments

Comments
 (0)