Skip to content

Commit a06012c

Browse files
committed
Revert "Abstractions for read-only arrays"
This reverts commit 23f3232.
1 parent 23f3232 commit a06012c

File tree

5 files changed

+7
-449
lines changed

5 files changed

+7
-449
lines changed

array_api_compat/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
NumPy Array API compatibility library
33
4-
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is
5-
compatible with the Array API standard https://data-apis.org/array-api/latest/.
6-
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
4+
This is a small wrapper around NumPy and CuPy that is compatible with the
5+
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
6+
https://numpy.org/neps/nep-0047-array-api-standard.html.
77
88
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with

array_api_compat/common/_helpers.py

Lines changed: 3 additions & 254 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
"""
88
from __future__ import annotations
99

10-
import operator
1110
from typing import TYPE_CHECKING
1211

1312
if TYPE_CHECKING:
14-
from typing import Callable, Literal, Optional, Union, Any
13+
from typing import Optional, Union, Any
1514
from ._typing import Array, Device
1615

1716
import sys
@@ -92,7 +91,7 @@ def is_cupy_array(x):
9291
import cupy as cp
9392

9493
# TODO: Should we reject ndarray subclasses?
95-
return isinstance(x, cp.ndarray)
94+
return isinstance(x, (cp.ndarray, cp.generic))
9695

9796
def is_torch_array(x):
9897
"""
@@ -788,7 +787,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788787
return x
789788
return x.to_device(device, stream=stream)
790789

791-
792790
def size(x):
793791
"""
794792
Return the total number of elements of x.
@@ -803,253 +801,6 @@ def size(x):
803801
return None
804802
return math.prod(x.shape)
805803

806-
807-
def is_writeable_array(x) -> bool:
808-
"""
809-
Return False if ``x.__setitem__`` is expected to raise; True otherwise
810-
"""
811-
if is_numpy_array(x):
812-
return x.flags.writeable
813-
if is_jax_array(x) or is_pydata_sparse_array(x):
814-
return False
815-
return True
816-
817-
818-
def _is_fancy_index(idx) -> bool:
819-
if not isinstance(idx, tuple):
820-
idx = (idx,)
821-
return any(
822-
isinstance(i, (list, tuple)) or is_array_api_obj(i)
823-
for i in idx
824-
)
825-
826-
827-
_undef = object()
828-
829-
830-
class at:
831-
"""
832-
Update operations for read-only arrays.
833-
834-
This implements ``jax.numpy.ndarray.at`` for all backends.
835-
836-
Keyword arguments are passed verbatim to backends that support the `ndarray.at`
837-
method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly
838-
ignored for backends that don't support them.
839-
840-
Additionally, this introduces support for the `copy` keyword for all backends:
841-
842-
None
843-
The array parameter *may* be modified in place if it is possible and beneficial
844-
for performance. You should not reuse it after calling this function.
845-
True
846-
Ensure that the inputs are not modified. This is the default.
847-
False
848-
Raise ValueError if a copy cannot be avoided.
849-
850-
Examples
851-
--------
852-
Given either of these equivalent expressions::
853-
854-
x = at(x)[1].add(2, copy=None)
855-
x = at(x, 1).add(2, copy=None)
856-
857-
If x is a JAX array, they are the same as::
858-
859-
x = x.at[1].add(2)
860-
861-
If x is a read-only numpy array, they are the same as::
862-
863-
x = x.copy()
864-
x[1] += 2
865-
866-
Otherwise, they are the same as::
867-
868-
x[1] += 2
869-
870-
Warning
871-
-------
872-
When you use copy=None, you should always immediately overwrite
873-
the parameter array::
874-
875-
x = at(x, 0).set(2, copy=None)
876-
877-
The anti-pattern below must be avoided, as it will result in different behaviour
878-
on read-only versus writeable arrays::
879-
880-
x = xp.asarray([0, 0, 0])
881-
y = at(x, 0).set(2, copy=None)
882-
z = at(x, 1).set(3, copy=None)
883-
884-
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
885-
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
886-
887-
Warning
888-
-------
889-
The behaviour of update methods when the index is an array of integers which
890-
contains multiple occurrences of the same index is undefined;
891-
e.g. ``at(x, [0, 0]).set(2)``
892-
893-
Note
894-
----
895-
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
896-
897-
See Also
898-
--------
899-
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
900-
"""
901-
902-
__slots__ = ("x", "idx")
903-
904-
def __init__(self, x, idx=_undef, /):
905-
self.x = x
906-
self.idx = idx
907-
908-
def __getitem__(self, idx):
909-
"""
910-
Allow for the alternate syntax ``at(x)[start:stop:step]``,
911-
which looks prettier than ``at(x, slice(start, stop, step))``
912-
and feels more intuitive coming from the JAX documentation.
913-
"""
914-
if self.idx is not _undef:
915-
raise ValueError("Index has already been set")
916-
self.idx = idx
917-
return self
918-
919-
def _common(
920-
self,
921-
at_op: str,
922-
y=_undef,
923-
copy: bool | None | Literal["_force_false"] = True,
924-
**kwargs,
925-
):
926-
"""Perform common prepocessing.
927-
928-
Returns
929-
-------
930-
If the operation can be resolved by at[], (return value, None)
931-
Otherwise, (None, preprocessed x)
932-
"""
933-
if self.idx is _undef:
934-
raise TypeError(
935-
"Index has not been set.\n"
936-
"Usage: either\n"
937-
" at(x, idx).set(value)\n"
938-
"or\n"
939-
" at(x)[idx].set(value)\n"
940-
"(same for all other methods)."
941-
)
942-
943-
x = self.x
944-
945-
if copy is False:
946-
if not is_writeable_array(x) or is_dask_array(x):
947-
raise ValueError("Cannot modify parameter in place")
948-
elif copy is None:
949-
copy = not is_writeable_array(x)
950-
elif copy == "_force_false":
951-
copy = False
952-
elif copy is not True:
953-
raise ValueError(f"Invalid value for copy: {copy!r}")
954-
955-
if is_jax_array(x):
956-
# Use JAX's at[]
957-
at_ = x.at[self.idx]
958-
args = (y,) if y is not _undef else ()
959-
return getattr(at_, at_op)(*args, **kwargs), None
960-
961-
# Emulate at[] behaviour for non-JAX arrays
962-
if copy:
963-
# FIXME We blindly expect the output of x.copy() to be always writeable.
964-
# This holds true for read-only numpy arrays, but not necessarily for
965-
# other backends.
966-
xp = array_namespace(x)
967-
x = xp.asarray(x, copy=True)
968-
969-
return None, x
970-
971-
def get(self, **kwargs):
972-
"""
973-
Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
974-
that the output is either a copy or a view; it also allows passing
975-
keyword arguments to the backend.
976-
"""
977-
# __getitem__ with a fancy index always returns a copy.
978-
# Avoid an unnecessary double copy.
979-
# If copy is forced to False, raise.
980-
if _is_fancy_index(self.idx):
981-
if kwargs.get("copy", True) is False:
982-
raise TypeError(
983-
"Indexing a numpy array with a fancy index always "
984-
"results in a copy"
985-
)
986-
# Skip copy inside _common, even if array is not writeable
987-
kwargs["copy"] = "_force_false"
988-
989-
res, x = self._common("get", **kwargs)
990-
if res is not None:
991-
return res
992-
return x[self.idx]
993-
994-
def set(self, y, /, **kwargs):
995-
"""Apply ``x[idx] = y`` and return the update array"""
996-
res, x = self._common("set", y, **kwargs)
997-
if res is not None:
998-
return res
999-
x[self.idx] = y
1000-
return x
1001-
1002-
def _iop(
1003-
self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
1004-
):
1005-
"""x[idx] += y or equivalent in-place operation on a subset of x
1006-
1007-
which is the same as saying
1008-
x[idx] = x[idx] + y
1009-
Note that this is not the same as
1010-
operator.iadd(x[idx], y)
1011-
Consider for example when x is a numpy array and idx is a fancy index, which
1012-
triggers a deep copy on __getitem__.
1013-
"""
1014-
res, x = self._common(at_op, y, **kwargs)
1015-
if res is not None:
1016-
return res
1017-
x[self.idx] = elwise_op(x[self.idx], y)
1018-
return x
1019-
1020-
def add(self, y, /, **kwargs):
1021-
"""Apply ``x[idx] += y`` and return the updated array"""
1022-
return self._iop("add", operator.add, y, **kwargs)
1023-
1024-
def subtract(self, y, /, **kwargs):
1025-
"""Apply ``x[idx] -= y`` and return the updated array"""
1026-
return self._iop("subtract", operator.sub, y, **kwargs)
1027-
1028-
def multiply(self, y, /, **kwargs):
1029-
"""Apply ``x[idx] *= y`` and return the updated array"""
1030-
return self._iop("multiply", operator.mul, y, **kwargs)
1031-
1032-
def divide(self, y, /, **kwargs):
1033-
"""Apply ``x[idx] /= y`` and return the updated array"""
1034-
return self._iop("divide", operator.truediv, y, **kwargs)
1035-
1036-
def power(self, y, /, **kwargs):
1037-
"""Apply ``x[idx] **= y`` and return the updated array"""
1038-
return self._iop("power", operator.pow, y, **kwargs)
1039-
1040-
def min(self, y, /, **kwargs):
1041-
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
1042-
xp = array_namespace(self.x)
1043-
y = xp.asarray(y)
1044-
return self._iop("min", xp.minimum, y, **kwargs)
1045-
1046-
def max(self, y, /, **kwargs):
1047-
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
1048-
xp = array_namespace(self.x)
1049-
y = xp.asarray(y)
1050-
return self._iop("max", xp.maximum, y, **kwargs)
1051-
1052-
1053804
__all__ = [
1054805
"array_namespace",
1055806
"device",
@@ -1070,10 +821,8 @@ def max(self, y, /, **kwargs):
1070821
"is_ndonnx_namespace",
1071822
"is_pydata_sparse_array",
1072823
"is_pydata_sparse_namespace",
1073-
"is_writeable_array",
1074824
"size",
1075825
"to_device",
1076-
"at",
1077826
]
1078827

1079-
_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']
828+
_all_ignore = ['sys', 'math', 'inspect', 'warnings']

docs/helper-functions.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ instead, which would be wrapped.
3636
.. autofunction:: device
3737
.. autofunction:: to_device
3838
.. autofunction:: size
39-
.. autoclass:: at(array[, index])
40-
:members:
4139

4240
Inspection Helpers
4341
------------------
@@ -53,7 +51,6 @@ yet.
5351
.. autofunction:: is_jax_array
5452
.. autofunction:: is_pydata_sparse_array
5553
.. autofunction:: is_ndonnx_array
56-
.. autofunction:: is_writeable_array
5754
.. autofunction:: is_numpy_namespace
5855
.. autofunction:: is_cupy_namespace
5956
.. autofunction:: is_torch_namespace

0 commit comments

Comments
 (0)