7
7
"""
8
8
from __future__ import annotations
9
9
10
- import operator
11
10
from typing import TYPE_CHECKING
12
11
13
12
if TYPE_CHECKING :
14
- from typing import Callable , Literal , Optional , Union , Any
13
+ from typing import Optional , Union , Any
15
14
from ._typing import Array , Device
16
15
17
16
import sys
@@ -92,7 +91,7 @@ def is_cupy_array(x):
92
91
import cupy as cp
93
92
94
93
# TODO: Should we reject ndarray subclasses?
95
- return isinstance (x , cp .ndarray )
94
+ return isinstance (x , ( cp .ndarray , cp . generic ) )
96
95
97
96
def is_torch_array (x ):
98
97
"""
@@ -788,7 +787,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788
787
return x
789
788
return x .to_device (device , stream = stream )
790
789
791
-
792
790
def size (x ):
793
791
"""
794
792
Return the total number of elements of x.
@@ -803,253 +801,6 @@ def size(x):
803
801
return None
804
802
return math .prod (x .shape )
805
803
806
-
807
- def is_writeable_array (x ) -> bool :
808
- """
809
- Return False if ``x.__setitem__`` is expected to raise; True otherwise
810
- """
811
- if is_numpy_array (x ):
812
- return x .flags .writeable
813
- if is_jax_array (x ) or is_pydata_sparse_array (x ):
814
- return False
815
- return True
816
-
817
-
818
- def _is_fancy_index (idx ) -> bool :
819
- if not isinstance (idx , tuple ):
820
- idx = (idx ,)
821
- return any (
822
- isinstance (i , (list , tuple )) or is_array_api_obj (i )
823
- for i in idx
824
- )
825
-
826
-
827
- _undef = object ()
828
-
829
-
830
- class at :
831
- """
832
- Update operations for read-only arrays.
833
-
834
- This implements ``jax.numpy.ndarray.at`` for all backends.
835
-
836
- Keyword arguments are passed verbatim to backends that support the `ndarray.at`
837
- method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly
838
- ignored for backends that don't support them.
839
-
840
- Additionally, this introduces support for the `copy` keyword for all backends:
841
-
842
- None
843
- The array parameter *may* be modified in place if it is possible and beneficial
844
- for performance. You should not reuse it after calling this function.
845
- True
846
- Ensure that the inputs are not modified. This is the default.
847
- False
848
- Raise ValueError if a copy cannot be avoided.
849
-
850
- Examples
851
- --------
852
- Given either of these equivalent expressions::
853
-
854
- x = at(x)[1].add(2, copy=None)
855
- x = at(x, 1).add(2, copy=None)
856
-
857
- If x is a JAX array, they are the same as::
858
-
859
- x = x.at[1].add(2)
860
-
861
- If x is a read-only numpy array, they are the same as::
862
-
863
- x = x.copy()
864
- x[1] += 2
865
-
866
- Otherwise, they are the same as::
867
-
868
- x[1] += 2
869
-
870
- Warning
871
- -------
872
- When you use copy=None, you should always immediately overwrite
873
- the parameter array::
874
-
875
- x = at(x, 0).set(2, copy=None)
876
-
877
- The anti-pattern below must be avoided, as it will result in different behaviour
878
- on read-only versus writeable arrays::
879
-
880
- x = xp.asarray([0, 0, 0])
881
- y = at(x, 0).set(2, copy=None)
882
- z = at(x, 1).set(3, copy=None)
883
-
884
- In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
885
- when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
886
-
887
- Warning
888
- -------
889
- The behaviour of update methods when the index is an array of integers which
890
- contains multiple occurrences of the same index is undefined;
891
- e.g. ``at(x, [0, 0]).set(2)``
892
-
893
- Note
894
- ----
895
- `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
896
-
897
- See Also
898
- --------
899
- `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
900
- """
901
-
902
- __slots__ = ("x" , "idx" )
903
-
904
- def __init__ (self , x , idx = _undef , / ):
905
- self .x = x
906
- self .idx = idx
907
-
908
- def __getitem__ (self , idx ):
909
- """
910
- Allow for the alternate syntax ``at(x)[start:stop:step]``,
911
- which looks prettier than ``at(x, slice(start, stop, step))``
912
- and feels more intuitive coming from the JAX documentation.
913
- """
914
- if self .idx is not _undef :
915
- raise ValueError ("Index has already been set" )
916
- self .idx = idx
917
- return self
918
-
919
- def _common (
920
- self ,
921
- at_op : str ,
922
- y = _undef ,
923
- copy : bool | None | Literal ["_force_false" ] = True ,
924
- ** kwargs ,
925
- ):
926
- """Perform common prepocessing.
927
-
928
- Returns
929
- -------
930
- If the operation can be resolved by at[], (return value, None)
931
- Otherwise, (None, preprocessed x)
932
- """
933
- if self .idx is _undef :
934
- raise TypeError (
935
- "Index has not been set.\n "
936
- "Usage: either\n "
937
- " at(x, idx).set(value)\n "
938
- "or\n "
939
- " at(x)[idx].set(value)\n "
940
- "(same for all other methods)."
941
- )
942
-
943
- x = self .x
944
-
945
- if copy is False :
946
- if not is_writeable_array (x ) or is_dask_array (x ):
947
- raise ValueError ("Cannot modify parameter in place" )
948
- elif copy is None :
949
- copy = not is_writeable_array (x )
950
- elif copy == "_force_false" :
951
- copy = False
952
- elif copy is not True :
953
- raise ValueError (f"Invalid value for copy: { copy !r} " )
954
-
955
- if is_jax_array (x ):
956
- # Use JAX's at[]
957
- at_ = x .at [self .idx ]
958
- args = (y ,) if y is not _undef else ()
959
- return getattr (at_ , at_op )(* args , ** kwargs ), None
960
-
961
- # Emulate at[] behaviour for non-JAX arrays
962
- if copy :
963
- # FIXME We blindly expect the output of x.copy() to be always writeable.
964
- # This holds true for read-only numpy arrays, but not necessarily for
965
- # other backends.
966
- xp = array_namespace (x )
967
- x = xp .asarray (x , copy = True )
968
-
969
- return None , x
970
-
971
- def get (self , ** kwargs ):
972
- """
973
- Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
974
- that the output is either a copy or a view; it also allows passing
975
- keyword arguments to the backend.
976
- """
977
- # __getitem__ with a fancy index always returns a copy.
978
- # Avoid an unnecessary double copy.
979
- # If copy is forced to False, raise.
980
- if _is_fancy_index (self .idx ):
981
- if kwargs .get ("copy" , True ) is False :
982
- raise TypeError (
983
- "Indexing a numpy array with a fancy index always "
984
- "results in a copy"
985
- )
986
- # Skip copy inside _common, even if array is not writeable
987
- kwargs ["copy" ] = "_force_false"
988
-
989
- res , x = self ._common ("get" , ** kwargs )
990
- if res is not None :
991
- return res
992
- return x [self .idx ]
993
-
994
- def set (self , y , / , ** kwargs ):
995
- """Apply ``x[idx] = y`` and return the update array"""
996
- res , x = self ._common ("set" , y , ** kwargs )
997
- if res is not None :
998
- return res
999
- x [self .idx ] = y
1000
- return x
1001
-
1002
- def _iop (
1003
- self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs
1004
- ):
1005
- """x[idx] += y or equivalent in-place operation on a subset of x
1006
-
1007
- which is the same as saying
1008
- x[idx] = x[idx] + y
1009
- Note that this is not the same as
1010
- operator.iadd(x[idx], y)
1011
- Consider for example when x is a numpy array and idx is a fancy index, which
1012
- triggers a deep copy on __getitem__.
1013
- """
1014
- res , x = self ._common (at_op , y , ** kwargs )
1015
- if res is not None :
1016
- return res
1017
- x [self .idx ] = elwise_op (x [self .idx ], y )
1018
- return x
1019
-
1020
- def add (self , y , / , ** kwargs ):
1021
- """Apply ``x[idx] += y`` and return the updated array"""
1022
- return self ._iop ("add" , operator .add , y , ** kwargs )
1023
-
1024
- def subtract (self , y , / , ** kwargs ):
1025
- """Apply ``x[idx] -= y`` and return the updated array"""
1026
- return self ._iop ("subtract" , operator .sub , y , ** kwargs )
1027
-
1028
- def multiply (self , y , / , ** kwargs ):
1029
- """Apply ``x[idx] *= y`` and return the updated array"""
1030
- return self ._iop ("multiply" , operator .mul , y , ** kwargs )
1031
-
1032
- def divide (self , y , / , ** kwargs ):
1033
- """Apply ``x[idx] /= y`` and return the updated array"""
1034
- return self ._iop ("divide" , operator .truediv , y , ** kwargs )
1035
-
1036
- def power (self , y , / , ** kwargs ):
1037
- """Apply ``x[idx] **= y`` and return the updated array"""
1038
- return self ._iop ("power" , operator .pow , y , ** kwargs )
1039
-
1040
- def min (self , y , / , ** kwargs ):
1041
- """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
1042
- xp = array_namespace (self .x )
1043
- y = xp .asarray (y )
1044
- return self ._iop ("min" , xp .minimum , y , ** kwargs )
1045
-
1046
- def max (self , y , / , ** kwargs ):
1047
- """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
1048
- xp = array_namespace (self .x )
1049
- y = xp .asarray (y )
1050
- return self ._iop ("max" , xp .maximum , y , ** kwargs )
1051
-
1052
-
1053
804
__all__ = [
1054
805
"array_namespace" ,
1055
806
"device" ,
@@ -1070,10 +821,8 @@ def max(self, y, /, **kwargs):
1070
821
"is_ndonnx_namespace" ,
1071
822
"is_pydata_sparse_array" ,
1072
823
"is_pydata_sparse_namespace" ,
1073
- "is_writeable_array" ,
1074
824
"size" ,
1075
825
"to_device" ,
1076
- "at" ,
1077
826
]
1078
827
1079
- _all_ignore = ['inspect ' , 'math' , 'operator ' , 'warnings' , 'sys ' ]
828
+ _all_ignore = ['sys ' , 'math' , 'inspect ' , 'warnings' ]
0 commit comments