2
2
3
3
import operator
4
4
import warnings
5
- from collections .abc import Callable
6
- from typing import Any
7
5
8
- if typing .TYPE_CHECKING :
9
- from ._lib ._typing import Array , ModuleType
6
+ # https://github.com/pylint-dev/pylint/issues/10112
7
+ from collections .abc import Callable # pylint: disable=import-error
8
+ from typing import ClassVar
10
9
11
10
from ._lib import _utils
12
11
from ._lib ._compat import (
15
14
is_dask_array ,
16
15
is_writeable_array ,
17
16
)
18
- from ._lib ._typing import Array , ModuleType
17
+ from ._lib ._typing import Array , Index , ModuleType , Untyped
19
18
20
19
__all__ = [
21
20
"at" ,
@@ -562,7 +561,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
562
561
_undef = object ()
563
562
564
563
565
- class at : # noqa: N801
564
+ class at : # pylint: disable=invalid-name
566
565
"""
567
566
Update operations for read-only arrays.
568
567
@@ -654,14 +653,14 @@ class at: # noqa: N801
654
653
"""
655
654
656
655
x : Array
657
- idx : Any
658
- __slots__ = ("idx" , "x" )
656
+ idx : Index
657
+ __slots__ : ClassVar [ tuple [ str , str ]] = ("idx" , "x" )
659
658
660
- def __init__ (self , x : Array , idx : Any = _undef , / ):
659
+ def __init__ (self , x : Array , idx : Index = _undef , / ):
661
660
self .x = x
662
661
self .idx = idx
663
662
664
- def __getitem__ (self , idx : Any ) -> Any :
663
+ def __getitem__ (self , idx : Index ) -> at :
665
664
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
666
665
which looks prettier than ``at(x, slice(start, stop, step))``
667
666
and feels more intuitive coming from the JAX documentation.
@@ -680,8 +679,8 @@ def _common(
680
679
copy : bool | None = True ,
681
680
xp : ModuleType | None = None ,
682
681
_is_update : bool = True ,
683
- ** kwargs : Any ,
684
- ) -> tuple [Any , None ] | tuple [None , Array ]:
682
+ ** kwargs : Untyped ,
683
+ ) -> tuple [Untyped , None ] | tuple [None , Array ]:
685
684
"""Perform common prepocessing.
686
685
687
686
Returns
@@ -709,11 +708,11 @@ def _common(
709
708
if not writeable :
710
709
msg = "Cannot modify parameter in place"
711
710
raise ValueError (msg )
712
- elif copy is None :
711
+ elif copy is None : # type: ignore[redundant-expr]
713
712
writeable = is_writeable_array (x )
714
713
copy = _is_update and not writeable
715
714
else :
716
- msg = f"Invalid value for copy: { copy !r} " # type: ignore[unreachable]
715
+ msg = f"Invalid value for copy: { copy !r} " # type: ignore[unreachable] # pyright: ignore[reportUnreachable]
717
716
raise ValueError (msg )
718
717
719
718
if copy :
@@ -744,7 +743,7 @@ def _common(
744
743
745
744
return None , x
746
745
747
- def get (self , ** kwargs : Any ) -> Any :
746
+ def get (self , ** kwargs : Untyped ) -> Untyped :
748
747
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
749
748
that the output is either a copy or a view; it also allows passing
750
749
keyword arguments to the backend.
@@ -769,7 +768,7 @@ def get(self, **kwargs: Any) -> Any:
769
768
assert x is not None
770
769
return x [self .idx ]
771
770
772
- def set (self , y : Array , / , ** kwargs : Any ) -> Array :
771
+ def set (self , y : Array , / , ** kwargs : Untyped ) -> Array :
773
772
"""Apply ``x[idx] = y`` and return the update array"""
774
773
res , x = self ._common ("set" , y , ** kwargs )
775
774
if res is not None :
@@ -784,7 +783,7 @@ def _iop(
784
783
elwise_op : Callable [[Array , Array ], Array ],
785
784
y : Array ,
786
785
/ ,
787
- ** kwargs : Any ,
786
+ ** kwargs : Untyped ,
788
787
) -> Array :
789
788
"""x[idx] += y or equivalent in-place operation on a subset of x
790
789
@@ -802,33 +801,33 @@ def _iop(
802
801
x [self .idx ] = elwise_op (x [self .idx ], y )
803
802
return x
804
803
805
- def add (self , y : Array , / , ** kwargs : Any ) -> Array :
804
+ def add (self , y : Array , / , ** kwargs : Untyped ) -> Array :
806
805
"""Apply ``x[idx] += y`` and return the updated array"""
807
806
return self ._iop ("add" , operator .add , y , ** kwargs )
808
807
809
- def subtract (self , y : Array , / , ** kwargs : Any ) -> Array :
808
+ def subtract (self , y : Array , / , ** kwargs : Untyped ) -> Array :
810
809
"""Apply ``x[idx] -= y`` and return the updated array"""
811
810
return self ._iop ("subtract" , operator .sub , y , ** kwargs )
812
811
813
- def multiply (self , y : Array , / , ** kwargs : Any ) -> Array :
812
+ def multiply (self , y : Array , / , ** kwargs : Untyped ) -> Array :
814
813
"""Apply ``x[idx] *= y`` and return the updated array"""
815
814
return self ._iop ("multiply" , operator .mul , y , ** kwargs )
816
815
817
- def divide (self , y : Array , / , ** kwargs : Any ) -> Array :
816
+ def divide (self , y : Array , / , ** kwargs : Untyped ) -> Array :
818
817
"""Apply ``x[idx] /= y`` and return the updated array"""
819
818
return self ._iop ("divide" , operator .truediv , y , ** kwargs )
820
819
821
- def power (self , y : Array , / , ** kwargs : Any ) -> Array :
820
+ def power (self , y : Array , / , ** kwargs : Untyped ) -> Array :
822
821
"""Apply ``x[idx] **= y`` and return the updated array"""
823
822
return self ._iop ("power" , operator .pow , y , ** kwargs )
824
823
825
- def min (self , y : Array , / , ** kwargs : Any ) -> Array :
824
+ def min (self , y : Array , / , ** kwargs : Untyped ) -> Array :
826
825
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
827
826
xp = array_namespace (self .x )
828
827
y = xp .asarray (y )
829
828
return self ._iop ("min" , xp .minimum , y , ** kwargs )
830
829
831
- def max (self , y : Array , / , ** kwargs : Any ) -> Array :
830
+ def max (self , y : Array , / , ** kwargs : Untyped ) -> Array :
832
831
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
833
832
xp = array_namespace (self .x )
834
833
y = xp .asarray (y )
0 commit comments