24
24
import numpy as np
25
25
import numpy .typing as npt
26
26
27
- from ._creation_functions import _default , _Default , asarray
27
+ from ._creation_functions import _undef , Undef , asarray
28
28
from ._dtypes import (
29
29
DType ,
30
30
_all_dtypes ,
@@ -101,7 +101,7 @@ class Array:
101
101
# Use a custom constructor instead of __init__, as manually initializing
102
102
# this class is not supported API.
103
103
@classmethod
104
- def _new (cls , x : np . ndarray | np . generic , / , device : Device | None ) -> Array :
104
+ def _new (cls , x : npt . NDArray [ Any ] , / , device : Device | None ) -> Array :
105
105
"""
106
106
This is a private method for initializing the array API Array
107
107
object.
@@ -611,37 +611,37 @@ def __dlpack__(
611
611
/ ,
612
612
* ,
613
613
stream : Any = None ,
614
- max_version : tuple [int , int ] | None | _Default = _default ,
615
- dl_device : tuple [IntEnum , int ] | None | _Default = _default ,
616
- copy : bool | None | _Default = _default ,
614
+ max_version : tuple [int , int ] | None | Undef = _undef ,
615
+ dl_device : tuple [IntEnum , int ] | None | Undef = _undef ,
616
+ copy : bool | None | Undef = _undef ,
617
617
) -> PyCapsule :
618
618
"""
619
619
Performs the operation __dlpack__.
620
620
"""
621
621
if get_array_api_strict_flags ()['api_version' ] < '2023.12' :
622
- if max_version is not _default :
622
+ if max_version is not _undef :
623
623
raise ValueError ("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API" )
624
- if dl_device is not _default :
624
+ if dl_device is not _undef :
625
625
raise ValueError ("The device argument to __dlpack__ requires at least version 2023.12 of the array API" )
626
- if copy is not _default :
626
+ if copy is not _undef :
627
627
raise ValueError ("The copy argument to __dlpack__ requires at least version 2023.12 of the array API" )
628
628
629
629
if np .lib .NumpyVersion (np .__version__ ) < '2.1.0' :
630
- if max_version not in [_default , None ]:
630
+ if max_version not in [_undef , None ]:
631
631
raise NotImplementedError ("The max_version argument to __dlpack__ is not yet implemented" )
632
- if dl_device not in [_default , None ]:
632
+ if dl_device not in [_undef , None ]:
633
633
raise NotImplementedError ("The device argument to __dlpack__ is not yet implemented" )
634
- if copy not in [_default , None ]:
634
+ if copy not in [_undef , None ]:
635
635
raise NotImplementedError ("The copy argument to __dlpack__ is not yet implemented" )
636
636
637
637
return self ._array .__dlpack__ (stream = stream )
638
638
else :
639
639
kwargs = {'stream' : stream }
640
- if max_version is not _default :
640
+ if max_version is not _undef :
641
641
kwargs ['max_version' ] = max_version
642
- if dl_device is not _default :
642
+ if dl_device is not _undef :
643
643
kwargs ['dl_device' ] = dl_device
644
- if copy is not _default :
644
+ if copy is not _undef :
645
645
kwargs ['copy' ] = copy
646
646
return self ._array .__dlpack__ (** kwargs )
647
647
@@ -678,7 +678,7 @@ def __float__(self) -> float:
678
678
res = self ._array .__float__ ()
679
679
return res
680
680
681
- def __floordiv__ (self , other : Array | complex , / ) -> Array :
681
+ def __floordiv__ (self , other : Array | float , / ) -> Array :
682
682
"""
683
683
Performs the operation __floordiv__.
684
684
"""
@@ -690,7 +690,7 @@ def __floordiv__(self, other: Array | complex, /) -> Array:
690
690
res = self ._array .__floordiv__ (other ._array )
691
691
return self .__class__ ._new (res , device = self .device )
692
692
693
- def __ge__ (self , other : Array | complex , / ) -> Array :
693
+ def __ge__ (self , other : Array | float , / ) -> Array :
694
694
"""
695
695
Performs the operation __ge__.
696
696
"""
@@ -725,7 +725,7 @@ def __getitem__(
725
725
res = self ._array .__getitem__ (np_key )
726
726
return self ._new (res , device = self .device )
727
727
728
- def __gt__ (self , other : Array | complex , / ) -> Array :
728
+ def __gt__ (self , other : Array | float , / ) -> Array :
729
729
"""
730
730
Performs the operation __gt__.
731
731
"""
@@ -780,7 +780,7 @@ def __iter__(self) -> Iterator[Array]:
780
780
# implemented, which implies iteration on 1-D arrays.
781
781
return (Array ._new (i , device = self .device ) for i in self ._array )
782
782
783
- def __le__ (self , other : Array | complex , / ) -> Array :
783
+ def __le__ (self , other : Array | float , / ) -> Array :
784
784
"""
785
785
Performs the operation __le__.
786
786
"""
@@ -804,7 +804,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
804
804
res = self ._array .__lshift__ (other ._array )
805
805
return self .__class__ ._new (res , device = self .device )
806
806
807
- def __lt__ (self , other : Array | complex , / ) -> Array :
807
+ def __lt__ (self , other : Array | float , / ) -> Array :
808
808
"""
809
809
Performs the operation __lt__.
810
810
"""
0 commit comments