Skip to content

Commit 244969b

Browse files
committed
WIP at() method
1 parent f9169fb commit 244969b

File tree

8 files changed

+480
-7
lines changed

8 files changed

+480
-7
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ ignore = [
230230
"PLR09", # Too many <...>
231231
"PLR2004", # Magic value used in comparison
232232
"ISC001", # Conflicts with formatter
233+
"PD008", # pandas-use-of-dot-at
233234
]
234235

235236
[tool.ruff.lint.per-file-ignores]

src/array_api_extra/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.3.3.dev0"
615

716
# pylint: disable=duplicate-code
817
__all__ = [
918
"__version__",
19+
"at",
1020
"atleast_nd",
1121
"cov",
1222
"create_diagonal",

src/array_api_extra/_funcs.py

Lines changed: 287 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
import typing
3+
import operator
44
import warnings
5+
from collections.abc import Callable
6+
from typing import Any
57

68
if typing.TYPE_CHECKING:
79
from ._lib._typing import Array, ModuleType
810

911
from ._lib import _utils
10-
from ._lib._compat import array_namespace
12+
from ._lib._compat import (
13+
array_namespace,
14+
is_array_api_obj,
15+
is_dask_array,
16+
is_writeable_array,
17+
)
18+
from ._lib._typing import Array, ModuleType
1119

1220
__all__ = [
21+
"at",
1322
"atleast_nd",
1423
"cov",
1524
"create_diagonal",
@@ -548,3 +557,279 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
548557
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
549558
)
550559
return xp.sin(y) / y
560+
561+
562+
_undef = object()
563+
564+
565+
class at: # noqa: N801
566+
"""
567+
Update operations for read-only arrays.
568+
569+
This implements ``jax.numpy.ndarray.at`` for all backends.
570+
571+
Parameters
572+
----------
573+
x : array
574+
Input array.
575+
idx : index, optional
576+
You may use two alternate syntaxes::
577+
578+
at(x, idx).set(value) # or get(), add(), etc.
579+
at(x)[idx].set(value)
580+
581+
copy : bool, optional
582+
True (default)
583+
Ensure that the inputs are not modified.
584+
False
585+
Ensure that the update operation writes back to the input.
586+
Raise ValueError if a copy cannot be avoided.
587+
None
588+
The array parameter *may* be modified in place if it is possible and
589+
beneficial for performance.
590+
You should not reuse it after calling this function.
591+
xp : array_namespace, optional
592+
The standard-compatible namespace for `x`. Default: infer
593+
594+
**kwargs:
595+
If the backend supports an `at` method, any additional keyword
596+
arguments are passed to it verbatim; e.g. this allows passing
597+
``indices_are_sorted=True`` to JAX.
598+
599+
Returns
600+
-------
601+
Updated input array.
602+
603+
Examples
604+
--------
605+
Given either of these equivalent expressions::
606+
607+
x = at(x)[1].add(2, copy=None)
608+
x = at(x, 1).add(2, copy=None)
609+
610+
If x is a JAX array, they are the same as::
611+
612+
x = x.at[1].add(2)
613+
614+
If x is a read-only numpy array, they are the same as::
615+
616+
x = x.copy()
617+
x[1] += 2
618+
619+
Otherwise, they are the same as::
620+
621+
x[1] += 2
622+
623+
Warning
624+
-------
625+
When you use copy=None, you should always immediately overwrite
626+
the parameter array::
627+
628+
x = at(x, 0).set(2, copy=None)
629+
630+
The anti-pattern below must be avoided, as it will result in different behaviour
631+
on read-only versus writeable arrays::
632+
633+
x = xp.asarray([0, 0, 0])
634+
y = at(x, 0).set(2, copy=None)
635+
z = at(x, 1).set(3, copy=None)
636+
637+
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
638+
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
639+
640+
Warning
641+
-------
642+
The array API standard does not support integer array indices.
643+
The behaviour of update methods when the index is an array of integers
644+
is undefined; this is particularly true when the index contains multiple
645+
occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``.
646+
647+
Note
648+
----
649+
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
650+
651+
See Also
652+
--------
653+
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
654+
"""
655+
656+
x: Array
657+
idx: Any
658+
__slots__ = ("idx", "x")
659+
660+
def __init__(self, x: Array, idx: Any = _undef, /):
661+
self.x = x
662+
self.idx = idx
663+
664+
def __getitem__(self, idx: Any) -> Any:
665+
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
666+
which looks prettier than ``at(x, slice(start, stop, step))``
667+
and feels more intuitive coming from the JAX documentation.
668+
"""
669+
if self.idx is not _undef:
670+
msg = "Index has already been set"
671+
raise ValueError(msg)
672+
self.idx = idx
673+
return self
674+
675+
def _common(
676+
self,
677+
at_op: str,
678+
y: Array = _undef,
679+
/,
680+
copy: bool | None = True,
681+
xp: ModuleType | None = None,
682+
_is_update: bool = True,
683+
**kwargs: Any,
684+
) -> tuple[Any, None] | tuple[None, Array]:
685+
"""Perform common prepocessing.
686+
687+
Returns
688+
-------
689+
If the operation can be resolved by at[], (return value, None)
690+
Otherwise, (None, preprocessed x)
691+
"""
692+
if self.idx is _undef:
693+
msg = (
694+
"Index has not been set.\n"
695+
"Usage: either\n"
696+
" at(x, idx).set(value)\n"
697+
"or\n"
698+
" at(x)[idx].set(value)\n"
699+
"(same for all other methods)."
700+
)
701+
raise TypeError(msg)
702+
703+
x = self.x
704+
705+
if copy is True:
706+
writeable = None
707+
elif copy is False:
708+
writeable = is_writeable_array(x)
709+
if not writeable:
710+
msg = "Cannot modify parameter in place"
711+
raise ValueError(msg)
712+
elif copy is None:
713+
writeable = is_writeable_array(x)
714+
copy = _is_update and not writeable
715+
else:
716+
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable]
717+
raise ValueError(msg)
718+
719+
if copy:
720+
try:
721+
at_ = x.at
722+
except AttributeError:
723+
# Emulate at[] behaviour for non-JAX arrays
724+
# with a copy followed by an update
725+
if xp is None:
726+
xp = array_namespace(x)
727+
# Create writeable copy of read-only numpy array
728+
x = xp.asarray(x, copy=True)
729+
if writeable is False:
730+
# A copy of a read-only numpy array is writeable
731+
writeable = None
732+
else:
733+
# Use JAX's at[] or other library that with the same duck-type API
734+
args = (y,) if y is not _undef else ()
735+
return getattr(at_[self.idx], at_op)(*args, **kwargs), None
736+
737+
if _is_update:
738+
if writeable is None:
739+
writeable = is_writeable_array(x)
740+
if not writeable:
741+
# sparse crashes here
742+
msg = f"Array {x} has no `at` method and is read-only"
743+
raise ValueError(msg)
744+
745+
return None, x
746+
747+
def get(self, **kwargs: Any) -> Any:
748+
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
749+
that the output is either a copy or a view; it also allows passing
750+
keyword arguments to the backend.
751+
"""
752+
if kwargs.get("copy") is False:
753+
if is_array_api_obj(self.idx):
754+
# Boolean index. Note that the array API spec
755+
# https://data-apis.org/array-api/latest/API_specification/indexing.html
756+
# does not allow for list, tuple, and tuples of slices plus one or more
757+
# one-dimensional array indices, although many backends support them.
758+
# So this check will encounter a lot of false negatives in real life,
759+
# which can be caught by testing the user code vs. array-api-strict.
760+
msg = "get() with an array index always returns a copy"
761+
raise ValueError(msg)
762+
if is_dask_array(self.x):
763+
msg = "get() on Dask arrays always returns a copy"
764+
raise ValueError(msg)
765+
766+
res, x = self._common("get", _is_update=False, **kwargs)
767+
if res is not None:
768+
return res
769+
assert x is not None
770+
return x[self.idx]
771+
772+
def set(self, y: Array, /, **kwargs: Any) -> Array:
773+
"""Apply ``x[idx] = y`` and return the update array"""
774+
res, x = self._common("set", y, **kwargs)
775+
if res is not None:
776+
return res
777+
assert x is not None
778+
x[self.idx] = y
779+
return x
780+
781+
def _iop(
782+
self,
783+
at_op: str,
784+
elwise_op: Callable[[Array, Array], Array],
785+
y: Array,
786+
/,
787+
**kwargs: Any,
788+
) -> Array:
789+
"""x[idx] += y or equivalent in-place operation on a subset of x
790+
791+
which is the same as saying
792+
x[idx] = x[idx] + y
793+
Note that this is not the same as
794+
operator.iadd(x[idx], y)
795+
Consider for example when x is a numpy array and idx is a fancy index, which
796+
triggers a deep copy on __getitem__.
797+
"""
798+
res, x = self._common(at_op, y, **kwargs)
799+
if res is not None:
800+
return res
801+
assert x is not None
802+
x[self.idx] = elwise_op(x[self.idx], y)
803+
return x
804+
805+
def add(self, y: Array, /, **kwargs: Any) -> Array:
806+
"""Apply ``x[idx] += y`` and return the updated array"""
807+
return self._iop("add", operator.add, y, **kwargs)
808+
809+
def subtract(self, y: Array, /, **kwargs: Any) -> Array:
810+
"""Apply ``x[idx] -= y`` and return the updated array"""
811+
return self._iop("subtract", operator.sub, y, **kwargs)
812+
813+
def multiply(self, y: Array, /, **kwargs: Any) -> Array:
814+
"""Apply ``x[idx] *= y`` and return the updated array"""
815+
return self._iop("multiply", operator.mul, y, **kwargs)
816+
817+
def divide(self, y: Array, /, **kwargs: Any) -> Array:
818+
"""Apply ``x[idx] /= y`` and return the updated array"""
819+
return self._iop("divide", operator.truediv, y, **kwargs)
820+
821+
def power(self, y: Array, /, **kwargs: Any) -> Array:
822+
"""Apply ``x[idx] **= y`` and return the updated array"""
823+
return self._iop("power", operator.pow, y, **kwargs)
824+
825+
def min(self, y: Array, /, **kwargs: Any) -> Array:
826+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
827+
xp = array_namespace(self.x)
828+
y = xp.asarray(y)
829+
return self._iop("min", xp.minimum, y, **kwargs)
830+
831+
def max(self, y: Array, /, **kwargs: Any) -> Array:
832+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
833+
xp = array_namespace(self.x)
834+
y = xp.asarray(y)
835+
return self._iop("max", xp.maximum, y, **kwargs)

src/array_api_extra/_lib/_compat.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,23 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace, # pyright: ignore[reportUnknownVariableType]
88
device, # pyright: ignore[reportUnknownVariableType]
9+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
11+
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
912
)
1013
except ImportError:
1114
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1215
array_namespace, # pyright: ignore[reportUnknownVariableType]
1316
device,
17+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
18+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
19+
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
1420
)
1521

16-
__all__ = [
22+
__all__ = (
1723
"array_namespace",
1824
"device",
19-
]
25+
"is_array_api_obj",
26+
"is_dask_array",
27+
"is_writeable_array",
28+
)

src/array_api_extra/_lib/_compat.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_dask_array(x: object, /) -> bool: ...
16+
def is_writeable_array(x: object, /) -> bool: ...

0 commit comments

Comments
 (0)