|
60 | 60 | ]
|
61 | 61 |
|
62 | 62 |
|
| 63 | +def _replace_nan_no_mask(a, val): |
| 64 | + """ |
| 65 | + Replace NaNs in array `a` with `val`. |
| 66 | +
|
| 67 | + If `a` is of inexact type, make a copy of `a`, replace NaNs with |
| 68 | + the `val` value, and return the copy together. If `a` is not of |
| 69 | + inexact type, do nothing and return `a`. |
| 70 | +
|
| 71 | + Parameters |
| 72 | + ---------- |
| 73 | + a : {dpnp.ndarray, usm_ndarray} |
| 74 | + Input array. |
| 75 | + val : float |
| 76 | + NaN values are set to `val` before doing the operation. |
| 77 | +
|
| 78 | + Returns |
| 79 | + ------- |
| 80 | + out : {dpnp.ndarray} |
| 81 | + If `a` is of inexact type, return a copy of `a` with the NaNs |
| 82 | + replaced by the fill value, otherwise return `a`. |
| 83 | + """ |
| 84 | + |
| 85 | + return dpnp.nan_to_num(a, nan=val, posinf=dpnp.inf, neginf=-dpnp.inf) |
| 86 | + |
| 87 | + |
63 | 88 | def _replace_nan(a, val):
|
64 | 89 | """
|
65 | 90 | Replace NaNs in array `a` with `val`.
|
@@ -315,7 +340,7 @@ def nancumprod(a, axis=None, dtype=None, out=None):
|
315 | 340 |
|
316 | 341 | """
|
317 | 342 |
|
318 |
| - a, _ = _replace_nan(a, 1) |
| 343 | + a = _replace_nan_no_mask(a, 1.0) |
319 | 344 | return dpnp.cumprod(a, axis=axis, dtype=dtype, out=out)
|
320 | 345 |
|
321 | 346 |
|
@@ -385,7 +410,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
|
385 | 410 |
|
386 | 411 | """
|
387 | 412 |
|
388 |
| - a, _ = _replace_nan(a, 0) |
| 413 | + a = _replace_nan_no_mask(a, 0.0) |
389 | 414 | return dpnp.cumsum(a, axis=axis, dtype=dtype, out=out)
|
390 | 415 |
|
391 | 416 |
|
@@ -884,7 +909,7 @@ def nanprod(
|
884 | 909 |
|
885 | 910 | """
|
886 | 911 |
|
887 |
| - a = dpnp.nan_to_num(a, nan=1.0, posinf=dpnp.inf, neginf=-dpnp.inf) |
| 912 | + a = _replace_nan_no_mask(a, 1.0) |
888 | 913 | return dpnp.prod(
|
889 | 914 | a,
|
890 | 915 | axis=axis,
|
@@ -988,7 +1013,7 @@ def nansum(
|
988 | 1013 |
|
989 | 1014 | """
|
990 | 1015 |
|
991 |
| - a = dpnp.nan_to_num(a, nan=0.0, posinf=dpnp.inf, neginf=-dpnp.inf) |
| 1016 | + a = _replace_nan_no_mask(a, 0.0) |
992 | 1017 | return dpnp.sum(
|
993 | 1018 | a,
|
994 | 1019 | axis=axis,
|
|
0 commit comments