Skip to content

Reuse dpnp.nan_to_num in dpnp.nansum and dpnp.nanprod #2339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 58 additions & 8 deletions dpnp/dpnp_iface_nanfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,36 @@
]


def _replace_nan_no_mask(a, val):
"""
Replace NaNs in array `a` with `val`.

If `a` is of inexact type, make a copy of `a`, replace NaNs with
the `val` value, and return the copy. If `a` is not of inexact type,
do nothing and return `a`.

Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Input array.
val : float
NaN values are set to `val` before doing the operation.

Returns
-------
out : dpnp.ndarray
If `a` is of inexact type, return a copy of `a` with the NaNs
replaced by the fill value, otherwise return `a`.

"""

dpnp.check_supported_arrays_type(a)
if dpnp.issubdtype(a.dtype, dpnp.inexact):
return dpnp.nan_to_num(a, nan=val, posinf=dpnp.inf, neginf=-dpnp.inf)

return a


def _replace_nan(a, val):
"""
Replace NaNs in array `a` with `val`.
Expand Down Expand Up @@ -107,6 +137,18 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False):

For full documentation refer to :obj:`numpy.nanargmax`.

Warning
-------
This function synchronizes in order to test for all-NaN slices in the array.
This may harm performance in some applications. To avoid synchronization,
the user is recommended to filter NaNs themselves and use `dpnp.argmax`
on the filtered array.

Warning
-------
The results cannot be trusted if a slice contains only NaNs
and -Infs.

Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Expand Down Expand Up @@ -136,8 +178,6 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False):
values ignoring NaNs. The returned array must have the default array
index data type.
For all-NaN slices ``ValueError`` is raised.
Warning: the results cannot be trusted if a slice contains only NaNs
and -Infs.

Limitations
-----------
Expand Down Expand Up @@ -181,6 +221,18 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):

For full documentation refer to :obj:`numpy.nanargmin`.

Warning
-------
This function synchronizes in order to test for all-NaN slices in the array.
This may harm performance in some applications. To avoid synchronization,
the user is recommended to filter NaNs themselves and use `dpnp.argmax`
on the filtered array.

Warning
-------
The results cannot be trusted if a slice contains only NaNs
and -Infs.

Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Expand Down Expand Up @@ -210,8 +262,6 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):
values ignoring NaNs. The returned array must have the default array
index data type.
For all-NaN slices ``ValueError`` is raised.
Warning: the results cannot be trusted if a slice contains only NaNs
and Infs.

Limitations
-----------
Expand Down Expand Up @@ -315,7 +365,7 @@ def nancumprod(a, axis=None, dtype=None, out=None):

"""

a, _ = _replace_nan(a, 1)
a = _replace_nan_no_mask(a, 1.0)
return dpnp.cumprod(a, axis=axis, dtype=dtype, out=out)


Expand Down Expand Up @@ -385,7 +435,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):

"""

a, _ = _replace_nan(a, 0)
a = _replace_nan_no_mask(a, 0.0)
return dpnp.cumsum(a, axis=axis, dtype=dtype, out=out)


Expand Down Expand Up @@ -884,7 +934,7 @@ def nanprod(

"""

a, _ = _replace_nan(a, 1)
a = _replace_nan_no_mask(a, 1.0)
return dpnp.prod(
a,
axis=axis,
Expand Down Expand Up @@ -988,7 +1038,7 @@ def nansum(

"""

a, _ = _replace_nan(a, 0)
a = _replace_nan_no_mask(a, 0.0)
return dpnp.sum(
a,
axis=axis,
Expand Down
Loading