|
60 | 60 | ]
|
61 | 61 |
|
62 | 62 |
|
| 63 | +def _replace_nan_no_mask(a, val): |
| 64 | + """ |
| 65 | + Replace NaNs in array `a` with `val`. |
| 66 | +
|
| 67 | + If `a` is of inexact type, make a copy of `a`, replace NaNs with |
| 68 | + the `val` value, and return the copy. If `a` is not of inexact type, |
| 69 | + do nothing and return `a`. |
| 70 | +
|
| 71 | + Parameters |
| 72 | + ---------- |
| 73 | + a : {dpnp.ndarray, usm_ndarray} |
| 74 | + Input array. |
| 75 | + val : float |
| 76 | + NaN values are set to `val` before doing the operation. |
| 77 | +
|
| 78 | + Returns |
| 79 | + ------- |
| 80 | + out : dpnp.ndarray |
| 81 | + If `a` is of inexact type, return a copy of `a` with the NaNs |
| 82 | + replaced by the fill value, otherwise return `a`. |
| 83 | +
|
| 84 | + """ |
| 85 | + |
| 86 | + dpnp.check_supported_arrays_type(a) |
| 87 | + if dpnp.issubdtype(a.dtype, dpnp.inexact): |
| 88 | + return dpnp.nan_to_num(a, nan=val, posinf=dpnp.inf, neginf=-dpnp.inf) |
| 89 | + |
| 90 | + return a |
| 91 | + |
| 92 | + |
63 | 93 | def _replace_nan(a, val):
|
64 | 94 | """
|
65 | 95 | Replace NaNs in array `a` with `val`.
|
@@ -107,6 +137,18 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False):
|
107 | 137 |
|
108 | 138 | For full documentation refer to :obj:`numpy.nanargmax`.
|
109 | 139 |
|
| 140 | + Warning |
| 141 | + ------- |
| 142 | + This function synchronizes in order to test for all-NaN slices in the array. |
| 143 | + This may harm performance in some applications. To avoid synchronization, |
| 144 | + the user is recommended to filter NaNs themselves and use `dpnp.argmax` |
| 145 | + on the filtered array. |
| 146 | +
|
| 147 | + Warning |
| 148 | + ------- |
| 149 | + The results cannot be trusted if a slice contains only NaNs |
| 150 | + and -Infs. |
| 151 | +
|
110 | 152 | Parameters
|
111 | 153 | ----------
|
112 | 154 | a : {dpnp.ndarray, usm_ndarray}
|
@@ -136,8 +178,6 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False):
|
136 | 178 | values ignoring NaNs. The returned array must have the default array
|
137 | 179 | index data type.
|
138 | 180 | For all-NaN slices ``ValueError`` is raised.
|
139 |
| - Warning: the results cannot be trusted if a slice contains only NaNs |
140 |
| - and -Infs. |
141 | 181 |
|
142 | 182 | Limitations
|
143 | 183 | -----------
|
@@ -181,6 +221,18 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):
|
181 | 221 |
|
182 | 222 | For full documentation refer to :obj:`numpy.nanargmin`.
|
183 | 223 |
|
| 224 | + Warning |
| 225 | + ------- |
| 226 | + This function synchronizes in order to test for all-NaN slices in the array. |
| 227 | + This may harm performance in some applications. To avoid synchronization, |
| 228 | + the user is recommended to filter NaNs themselves and use `dpnp.argmax` |
| 229 | + on the filtered array. |
| 230 | +
|
| 231 | + Warning |
| 232 | + ------- |
| 233 | + The results cannot be trusted if a slice contains only NaNs |
| 234 | + and -Infs. |
| 235 | +
|
184 | 236 | Parameters
|
185 | 237 | ----------
|
186 | 238 | a : {dpnp.ndarray, usm_ndarray}
|
@@ -210,8 +262,6 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):
|
210 | 262 | values ignoring NaNs. The returned array must have the default array
|
211 | 263 | index data type.
|
212 | 264 | For all-NaN slices ``ValueError`` is raised.
|
213 |
| - Warning: the results cannot be trusted if a slice contains only NaNs |
214 |
| - and Infs. |
215 | 265 |
|
216 | 266 | Limitations
|
217 | 267 | -----------
|
@@ -315,7 +365,7 @@ def nancumprod(a, axis=None, dtype=None, out=None):
|
315 | 365 |
|
316 | 366 | """
|
317 | 367 |
|
318 |
| - a, _ = _replace_nan(a, 1) |
| 368 | + a = _replace_nan_no_mask(a, 1.0) |
319 | 369 | return dpnp.cumprod(a, axis=axis, dtype=dtype, out=out)
|
320 | 370 |
|
321 | 371 |
|
@@ -385,7 +435,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
|
385 | 435 |
|
386 | 436 | """
|
387 | 437 |
|
388 |
| - a, _ = _replace_nan(a, 0) |
| 438 | + a = _replace_nan_no_mask(a, 0.0) |
389 | 439 | return dpnp.cumsum(a, axis=axis, dtype=dtype, out=out)
|
390 | 440 |
|
391 | 441 |
|
@@ -884,7 +934,7 @@ def nanprod(
|
884 | 934 |
|
885 | 935 | """
|
886 | 936 |
|
887 |
| - a, _ = _replace_nan(a, 1) |
| 937 | + a = _replace_nan_no_mask(a, 1.0) |
888 | 938 | return dpnp.prod(
|
889 | 939 | a,
|
890 | 940 | axis=axis,
|
@@ -988,7 +1038,7 @@ def nansum(
|
988 | 1038 |
|
989 | 1039 | """
|
990 | 1040 |
|
991 |
| - a, _ = _replace_nan(a, 0) |
| 1041 | + a = _replace_nan_no_mask(a, 0.0) |
992 | 1042 | return dpnp.sum(
|
993 | 1043 | a,
|
994 | 1044 | axis=axis,
|
|
0 commit comments