-
Notifications
You must be signed in to change notification settings - Fork 12
ENH: apply_where (migrate lazywhere from scipy) #141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
:nosignatures: | ||
:toctree: generated | ||
|
||
apply_where | ||
at | ||
atleast_nd | ||
broadcast_shapes | ||
|
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,17 +5,23 @@ | |
|
||
import math | ||
import warnings | ||
from collections.abc import Sequence | ||
from collections.abc import Callable, Sequence | ||
from types import ModuleType | ||
from typing import cast | ||
from typing import cast, overload | ||
|
||
from ._at import at | ||
from ._utils import _compat, _helpers | ||
from ._utils._compat import array_namespace, is_jax_array | ||
from ._utils._helpers import asarrays, eager_shape, ndindex | ||
from ._utils._compat import ( | ||
array_namespace, | ||
is_dask_namespace, | ||
is_jax_array, | ||
is_jax_namespace, | ||
) | ||
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex | ||
from ._utils._typing import Array | ||
|
||
__all__ = [ | ||
"apply_where", | ||
"atleast_nd", | ||
"broadcast_shapes", | ||
"cov", | ||
|
@@ -29,6 +35,146 @@ | |
] | ||
|
||
|
||
@overload
def apply_where(  # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
    cond: Array,
    args: Array | tuple[Array, ...],
    f1: Callable[..., Array],
    f2: Callable[..., Array],
    /,
    *,
    xp: ModuleType | None = None,
) -> Array: ...


@overload
def apply_where(  # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
    cond: Array,
    args: Array | tuple[Array, ...],
    f1: Callable[..., Array],
    /,
    *,
    fill_value: Array | int | float | complex | bool,
    xp: ModuleType | None = None,
) -> Array: ...


def apply_where(  # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
    cond: Array,
    args: Array | tuple[Array, ...],
    f1: Callable[..., Array],
    f2: Callable[..., Array] | None = None,
    /,
    *,
    fill_value: Array | int | float | complex | bool | None = None,
    xp: ModuleType | None = None,
) -> Array:
    """
    Run one of two elementwise functions depending on a condition.

    Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise
    when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``.

    Parameters
    ----------
    cond : array
        The condition, expressed as a boolean array.
    args : Array or tuple of Arrays
        Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
    f1 : callable
        Elementwise function of `args`, returning a single array.
        Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
    f2 : callable, optional
        Elementwise function of `args`, returning a single array.
        Where `cond` is False, output will be ``f2(arg0[~cond], arg1[~cond], ...)``.
        Mutually exclusive with `fill_value`.
    fill_value : Array or scalar, optional
        If provided, value with which to fill output array where `cond` is False.
        It does not need to be scalar; it needs however to be broadcastable with
        `cond` and `args`.
        Mutually exclusive with `f2`. You must provide one or the other.
    xp : array_namespace, optional
        The standard-compatible namespace for `cond` and `args`. Default: infer.

    Returns
    -------
    Array
        An array with elements from the output of `f1` where `cond` is True and either
        the output of `f2` or `fill_value` where `cond` is False. The returned array has
        data type determined by type promotion rules between the output of `f1` and
        either `fill_value` or the output of `f2`.

    Raises
    ------
    TypeError
        If both or neither of `f2` and `fill_value` are given.

    Notes
    -----
    ``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even
    when `cond` is False, and `f2` when cond is True. This function evaluates each
    function only for their matching condition, if the backend allows for it.

    On Dask, `f1` and `f2` are applied to the individual chunks and should use functions
    from the namespace of the chunks.

    Examples
    --------
    >>> a = xp.asarray([5, 4, 3])
    >>> b = xp.asarray([0, 2, 2])
    >>> def f(a, b):
    ...     return a // b
    >>> apply_where(b != 0, (a, b), f, fill_value=xp.nan)
    array([ nan,  2.,  1.])
    """
    # Parse and normalize arguments
    if (f2 is None) == (fill_value is None):
        msg = "Exactly one of `fill_value` or `f2` must be given."
        raise TypeError(msg)
    args_ = list(args) if isinstance(args, tuple) else [args]
    del args

    xp = array_namespace(cond, *args_) if xp is None else xp

    if fill_value is None or isinstance(fill_value, (int, float, complex)):
        # No fill value, or a Python scalar (bool is a subclass of int):
        # there is nothing to broadcast.
        cond, *args_ = xp.broadcast_arrays(cond, *args_)
    else:
        # Any array fill value, 0-dimensional ones included, is broadcast with
        # the other inputs so that the helper is not handed an unbroadcast
        # 0-d array to use as the scalar `fill_value` of `full_like`.
        cond, fill_value, *args_ = xp.broadcast_arrays(cond, fill_value, *args_)

    if is_dask_namespace(xp):
        meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
        # map_blocks doesn't descend into tuples of Arrays
        return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
    return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
|
||
|
||
def _apply_where(  # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
    cond: Array,
    f1: Callable[..., Array],
    f2: Callable[..., Array] | None,
    fill_value: Array | int | float | complex | bool | None,
    *args: Array,
    xp: ModuleType,
) -> Array:
    """
    Evaluate `f1` / `f2` (or apply `fill_value`) on inputs prepared by `apply_where`.

    Helper of `apply_where`. On Dask, this runs on a single chunk.
    """
    if is_jax_namespace(xp):
        # jax.jit does not support assignment by boolean mask, so both branches
        # are evaluated on the full inputs and merged with `where`.
        # Reviewer note: JAX-on-Dask is currently unsupported by Dask, which is
        # the only reason this check lives here rather than in the caller.
        on_true = f1(*args)
        on_false = fill_value if f2 is None else f2(*args)
        return xp.where(cond, on_true, on_false)

    # Evaluate f1 only on the elements selected by the mask.
    true_vals = f1(*[arr[cond] for arr in args])

    if f2 is not None:
        # Evaluate f2 only on the complementary elements and seed the output
        # with them.
        inverse = ~cond
        false_vals = f2(*[arr[inverse] for arr in args])
        out_dtype = xp.result_type(true_vals, false_vals)
        result = xp.empty_like(cond, dtype=out_dtype)
        result = at(result, inverse).set(false_vals)
    else:
        out_dtype = xp.result_type(true_vals, fill_value)
        if getattr(fill_value, "ndim", 0):
            # Array fill value with at least one dimension: a copy of it, cast
            # to the promoted dtype, becomes the output buffer.
            result = xp.astype(fill_value, out_dtype, copy=True)
        else:
            result = xp.full_like(cond, dtype=out_dtype, fill_value=fill_value)

    # Overwrite the positions where cond is True with the output of f1.
    return at(result, cond).set(true_vals)
|
||
|
||
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: | ||
""" | ||
Recursively expand the dimension of an array to at least `ndim`. | ||
|
@@ -393,12 +539,15 @@ def isclose( | |
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating")) | ||
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating")) | ||
if a_inexact or b_inexact: | ||
# FIXME: use scipy's lazywhere to suppress warnings on inf | ||
out = xp.where( | ||
# prevent warnings on numpy and dask on inf - inf | ||
mxp = meta_namespace(a, b, xp=xp) | ||
out = apply_where( | ||
xp.isinf(a) | xp.isinf(b), | ||
xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)), | ||
(a, b), | ||
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType] | ||
# Note: inf <= inf is True! | ||
xp.abs(a - b) <= (atol + rtol * xp.abs(b)), | ||
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType] | ||
xp=xp, | ||
) | ||
if equal_nan: | ||
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out) | ||
|
Uh oh!
There was an error while loading. Please reload this page.