Skip to content

Commit 23f3232

Browse files
committed
Abstractions for read-only arrays
1 parent ee25aae commit 23f3232

File tree

5 files changed

+449
-7
lines changed

5 files changed

+449
-7
lines changed

array_api_compat/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
NumPy Array API compatibility library
33
4-
This is a small wrapper around NumPy and CuPy that is compatible with the
5-
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
6-
https://numpy.org/neps/nep-0047-array-api-standard.html.
4+
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is
5+
compatible with the Array API standard https://data-apis.org/array-api/latest/.
6+
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
77
88
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with

array_api_compat/common/_helpers.py

Lines changed: 254 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
"""
88
from __future__ import annotations
99

10+
import operator
1011
from typing import TYPE_CHECKING
1112

1213
if TYPE_CHECKING:
13-
from typing import Optional, Union, Any
14+
from typing import Callable, Literal, Optional, Union, Any
1415
from ._typing import Array, Device
1516

1617
import sys
@@ -91,7 +92,7 @@ def is_cupy_array(x):
9192
import cupy as cp
9293

9394
# TODO: Should we reject ndarray subclasses?
94-
return isinstance(x, (cp.ndarray, cp.generic))
95+
return isinstance(x, cp.ndarray)
9596

9697
def is_torch_array(x):
9798
"""
@@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
787788
return x
788789
return x.to_device(device, stream=stream)
789790

791+
790792
def size(x):
791793
"""
792794
Return the total number of elements of x.
@@ -801,6 +803,253 @@ def size(x):
801803
return None
802804
return math.prod(x.shape)
803805

806+
807+
def is_writeable_array(x) -> bool:
    """
    Tell whether item assignment on ``x`` is expected to succeed.

    Returns
    -------
    bool
        ``x.flags.writeable`` when ``x`` is a NumPy array; False when ``x`` is
        a JAX or pydata/sparse array; True for anything else.
    """
    if is_numpy_array(x):
        # NumPy arrays carry their own read-only flag
        return x.flags.writeable
    # These backends are treated as never supporting ``x.__setitem__``
    return not (is_jax_array(x) or is_pydata_sparse_array(x))
816+
817+
818+
def _is_fancy_index(idx) -> bool:
    """
    Return True if ``idx``, or any element of it when it is a tuple, is a list,
    a tuple, or an array API object; False otherwise.
    """
    # Normalise a bare index to a 1-tuple so both forms share the same loop
    components = idx if isinstance(idx, tuple) else (idx,)
    for item in components:
        if isinstance(item, (list, tuple)) or is_array_api_obj(item):
            return True
    return False
825+
826+
827+
# Unique sentinel used as the default of optional parameters (``idx`` in
# ``at.__init__``, ``y`` in ``at._common``) to detect that no value was passed.
_undef = object()
828+
829+
830+
class at:
    """
    Update operations for read-only arrays.

    This implements ``jax.numpy.ndarray.at`` for all backends.

    Keyword arguments are passed verbatim to backends that support the `ndarray.at`
    method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly
    ignored for backends that don't support them.

    Additionally, this introduces support for the `copy` keyword for all backends:

    None
        The array parameter *may* be modified in place if it is possible and beneficial
        for performance. You should not reuse it after calling this function.
    True
        Ensure that the inputs are not modified. This is the default.
    False
        Raise ValueError if a copy cannot be avoided.

    Examples
    --------
    Given either of these equivalent expressions::

      x = at(x)[1].add(2, copy=None)
      x = at(x, 1).add(2, copy=None)

    If x is a JAX array, they are the same as::

      x = x.at[1].add(2)

    If x is a read-only numpy array, they are the same as::

      x = x.copy()
      x[1] += 2

    Otherwise, they are the same as::

      x[1] += 2

    Warning
    -------
    When you use copy=None, you should always immediately overwrite
    the parameter array::

      x = at(x, 0).set(2, copy=None)

    The anti-pattern below must be avoided, as it will result in different behaviour
    on read-only versus writeable arrays::

      x = xp.asarray([0, 0, 0])
      y = at(x, 0).set(2, copy=None)
      z = at(x, 1).set(3, copy=None)

    In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and
    ``z == [0, 3, 0]`` when x is read-only, whereas ``x == y == z == [2, 3, 0]``
    when x is writeable!

    Warning
    -------
    The behaviour of update methods when the index is an array of integers which
    contains multiple occurrences of the same index is undefined;
    e.g. ``at(x, [0, 0]).set(2)``

    Note
    ----
    `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.

    See Also
    --------
    `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
    """

    __slots__ = ("x", "idx")

    def __init__(self, x, idx=_undef, /):
        # idx may be left unset here and supplied later through at(x)[idx]
        self.x = x
        self.idx = idx

    def __getitem__(self, idx):
        """
        Allow for the alternate syntax ``at(x)[start:stop:step]``,
        which looks prettier than ``at(x, slice(start, stop, step))``
        and feels more intuitive coming from the JAX documentation.

        Stores ``idx`` on this object and returns the same object.

        Raises
        ------
        ValueError
            If the index has already been set.
        """
        if self.idx is not _undef:
            raise ValueError("Index has already been set")
        self.idx = idx
        return self

    def _common(
        self,
        at_op: str,
        y=_undef,
        copy: bool | None | Literal["_force_false"] = True,
        **kwargs,
    ):
        """Perform common preprocessing.

        Parameters
        ----------
        at_op
            Name of the method to invoke on ``x.at[idx]`` when x is a JAX array.
        y
            Operand of the update; omitted for operations that take none.
        copy
            True, False or None, as described in the class docstring.
            ``"_force_false"`` is an internal marker that skips both the copy
            and the writeability check.
        **kwargs
            Passed verbatim to the JAX method; ignored for other backends.

        Returns
        -------
        If the operation can be resolved by at[], (return value, None)
        Otherwise, (None, preprocessed x)

        Raises
        ------
        TypeError
            If the index has not been set.
        ValueError
            If ``copy=False`` and x is not writeable or is a dask array,
            or if ``copy`` has an invalid value.
        """
        if self.idx is _undef:
            raise TypeError(
                "Index has not been set.\n"
                "Usage: either\n"
                "    at(x, idx).set(value)\n"
                "or\n"
                "    at(x)[idx].set(value)\n"
                "(same for all other methods)."
            )

        x = self.x

        if copy is False:
            if not is_writeable_array(x) or is_dask_array(x):
                raise ValueError("Cannot modify parameter in place")
        elif copy is None:
            # Copy only when in-place modification is not possible
            copy = not is_writeable_array(x)
        elif copy == "_force_false":
            copy = False
        elif copy is not True:
            raise ValueError(f"Invalid value for copy: {copy!r}")

        if is_jax_array(x):
            # Use JAX's at[]
            at_ = x.at[self.idx]
            args = (y,) if y is not _undef else ()
            return getattr(at_, at_op)(*args, **kwargs), None

        # Emulate at[] behaviour for non-JAX arrays
        if copy:
            # FIXME We blindly expect the output of the copy to be always
            # writeable. This holds true for read-only numpy arrays, but not
            # necessarily for other backends.
            xp = array_namespace(x)
            x = xp.asarray(x, copy=True)

        return None, x

    def get(self, **kwargs):
        """
        Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
        that the output is either a copy or a view; it also allows passing
        keyword arguments to the backend.

        Raises
        ------
        TypeError
            If ``copy=False`` is requested together with a fancy index,
            or if the index has not been set.
        ValueError
            If ``copy=False`` and x cannot be modified in place,
            or if ``copy`` has an invalid value.
        """
        # __getitem__ with a fancy index always returns a copy.
        # Avoid an unnecessary double copy.
        # If copy is forced to False, raise.
        if _is_fancy_index(self.idx):
            copy = kwargs.get("copy", True)
            if copy is False:
                raise TypeError(
                    "Indexing a numpy array with a fancy index always "
                    "results in a copy"
                )
            # Validate here: the value is about to be replaced by the internal
            # marker, so _common would never get to see an invalid one.
            if copy is not True and copy is not None:
                raise ValueError(f"Invalid value for copy: {copy!r}")
            # Skip copy inside _common, even if array is not writeable
            kwargs["copy"] = "_force_false"

        res, x = self._common("get", **kwargs)
        if res is not None:
            return res
        return x[self.idx]

    def set(self, y, /, **kwargs):
        """Apply ``x[idx] = y`` and return the updated array"""
        res, x = self._common("set", y, **kwargs)
        if res is not None:
            return res
        x[self.idx] = y
        return x

    def _iop(
        self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
    ):
        """x[idx] += y or equivalent in-place operation on a subset of x

        which is the same as saying
            x[idx] = x[idx] + y
        Note that this is not the same as
            operator.iadd(x[idx], y)
        Consider for example when x is a numpy array and idx is a fancy index, which
        triggers a deep copy on __getitem__.
        """
        res, x = self._common(at_op, y, **kwargs)
        if res is not None:
            return res
        x[self.idx] = elwise_op(x[self.idx], y)
        return x

    def add(self, y, /, **kwargs):
        """Apply ``x[idx] += y`` and return the updated array"""
        return self._iop("add", operator.add, y, **kwargs)

    def subtract(self, y, /, **kwargs):
        """Apply ``x[idx] -= y`` and return the updated array"""
        return self._iop("subtract", operator.sub, y, **kwargs)

    def multiply(self, y, /, **kwargs):
        """Apply ``x[idx] *= y`` and return the updated array"""
        return self._iop("multiply", operator.mul, y, **kwargs)

    def divide(self, y, /, **kwargs):
        """Apply ``x[idx] /= y`` and return the updated array"""
        return self._iop("divide", operator.truediv, y, **kwargs)

    def power(self, y, /, **kwargs):
        """Apply ``x[idx] **= y`` and return the updated array"""
        return self._iop("power", operator.pow, y, **kwargs)

    def min(self, y, /, **kwargs):
        """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
        xp = array_namespace(self.x)
        # Convert y to an array of x's namespace before calling xp.minimum
        y = xp.asarray(y)
        return self._iop("min", xp.minimum, y, **kwargs)

    def max(self, y, /, **kwargs):
        """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
        xp = array_namespace(self.x)
        # Convert y to an array of x's namespace before calling xp.maximum
        y = xp.asarray(y)
        return self._iop("max", xp.maximum, y, **kwargs)
1051+
1052+
8041053
__all__ = [
8051054
"array_namespace",
8061055
"device",
@@ -821,8 +1070,10 @@ def size(x):
8211070
"is_ndonnx_namespace",
8221071
"is_pydata_sparse_array",
8231072
"is_pydata_sparse_namespace",
1073+
"is_writeable_array",
8241074
"size",
8251075
"to_device",
1076+
"at",
8261077
]
8271078

828-
_all_ignore = ['sys', 'math', 'inspect', 'warnings']
1079+
_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']

docs/helper-functions.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ instead, which would be wrapped.
3636
.. autofunction:: device
3737
.. autofunction:: to_device
3838
.. autofunction:: size
39+
.. autoclass:: at(array[, index])
40+
:members:
3941

4042
Inspection Helpers
4143
------------------
@@ -51,6 +53,7 @@ yet.
5153
.. autofunction:: is_jax_array
5254
.. autofunction:: is_pydata_sparse_array
5355
.. autofunction:: is_ndonnx_array
56+
.. autofunction:: is_writeable_array
5457
.. autofunction:: is_numpy_namespace
5558
.. autofunction:: is_cupy_namespace
5659
.. autofunction:: is_torch_namespace

0 commit comments

Comments
 (0)