Skip to content

Commit b40974c

Browse files
committed
code review
1 parent 44ebd51 commit b40974c

File tree

4 files changed

+33
-34
lines changed

4 files changed

+33
-34
lines changed

array_api_strict/_array_object.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import numpy as np
2525
import numpy.typing as npt
2626

27-
from ._creation_functions import _default, _Default, asarray
27+
from ._creation_functions import _undef, Undef, asarray
2828
from ._dtypes import (
2929
DType,
3030
_all_dtypes,
@@ -101,7 +101,7 @@ class Array:
101101
# Use a custom constructor instead of __init__, as manually initializing
102102
# this class is not supported API.
103103
@classmethod
104-
def _new(cls, x: np.ndarray | np.generic, /, device: Device | None) -> Array:
104+
def _new(cls, x: npt.NDArray[Any], /, device: Device | None) -> Array:
105105
"""
106106
This is a private method for initializing the array API Array
107107
object.
@@ -611,37 +611,37 @@ def __dlpack__(
611611
/,
612612
*,
613613
stream: Any = None,
614-
max_version: tuple[int, int] | None | _Default = _default,
615-
dl_device: tuple[IntEnum, int] | None | _Default = _default,
616-
copy: bool | None | _Default = _default,
614+
max_version: tuple[int, int] | None | Undef = _undef,
615+
dl_device: tuple[IntEnum, int] | None | Undef = _undef,
616+
copy: bool | None | Undef = _undef,
617617
) -> PyCapsule:
618618
"""
619619
Performs the operation __dlpack__.
620620
"""
621621
if get_array_api_strict_flags()['api_version'] < '2023.12':
622-
if max_version is not _default:
622+
if max_version is not _undef:
623623
raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API")
624-
if dl_device is not _default:
624+
if dl_device is not _undef:
625625
raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API")
626-
if copy is not _default:
626+
if copy is not _undef:
627627
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")
628628

629629
if np.lib.NumpyVersion(np.__version__) < '2.1.0':
630-
if max_version not in [_default, None]:
630+
if max_version not in [_undef, None]:
631631
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
632-
if dl_device not in [_default, None]:
632+
if dl_device not in [_undef, None]:
633633
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
634-
if copy not in [_default, None]:
634+
if copy not in [_undef, None]:
635635
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
636636

637637
return self._array.__dlpack__(stream=stream)
638638
else:
639639
kwargs = {'stream': stream}
640-
if max_version is not _default:
640+
if max_version is not _undef:
641641
kwargs['max_version'] = max_version
642-
if dl_device is not _default:
642+
if dl_device is not _undef:
643643
kwargs['dl_device'] = dl_device
644-
if copy is not _default:
644+
if copy is not _undef:
645645
kwargs['copy'] = copy
646646
return self._array.__dlpack__(**kwargs)
647647

@@ -678,7 +678,7 @@ def __float__(self) -> float:
678678
res = self._array.__float__()
679679
return res
680680

681-
def __floordiv__(self, other: Array | complex, /) -> Array:
681+
def __floordiv__(self, other: Array | float, /) -> Array:
682682
"""
683683
Performs the operation __floordiv__.
684684
"""
@@ -690,7 +690,7 @@ def __floordiv__(self, other: Array | complex, /) -> Array:
690690
res = self._array.__floordiv__(other._array)
691691
return self.__class__._new(res, device=self.device)
692692

693-
def __ge__(self, other: Array | complex, /) -> Array:
693+
def __ge__(self, other: Array | float, /) -> Array:
694694
"""
695695
Performs the operation __ge__.
696696
"""
@@ -725,7 +725,7 @@ def __getitem__(
725725
res = self._array.__getitem__(np_key)
726726
return self._new(res, device=self.device)
727727

728-
def __gt__(self, other: Array | complex, /) -> Array:
728+
def __gt__(self, other: Array | float, /) -> Array:
729729
"""
730730
Performs the operation __gt__.
731731
"""
@@ -780,7 +780,7 @@ def __iter__(self) -> Iterator[Array]:
780780
# implemented, which implies iteration on 1-D arrays.
781781
return (Array._new(i, device=self.device) for i in self._array)
782782

783-
def __le__(self, other: Array | complex, /) -> Array:
783+
def __le__(self, other: Array | float, /) -> Array:
784784
"""
785785
Performs the operation __le__.
786786
"""
@@ -804,7 +804,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
804804
res = self._array.__lshift__(other._array)
805805
return self.__class__._new(res, device=self.device)
806806

807-
def __lt__(self, other: Array | complex, /) -> Array:
807+
def __lt__(self, other: Array | float, /) -> Array:
808808
"""
809809
Performs the operation __lt__.
810810
"""

array_api_strict/_creation_functions.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
from ._array_object import Array, Device
1717

1818

19-
class _Default(Enum):
20-
DEFAULT = 0
19+
class Undef(Enum):
20+
UNDEF = 0
2121

2222

23-
_default = _Default.DEFAULT
23+
_undef = Undef.UNDEF
2424

2525

2626
@contextmanager
@@ -211,23 +211,23 @@ def from_dlpack(
211211
x: SupportsDLPack,
212212
/,
213213
*,
214-
device: Device | _Default | None = _default,
215-
copy: bool | _Default | None = _default,
214+
device: Device | Undef | None = _undef,
215+
copy: bool | Undef | None = _undef,
216216
) -> Array:
217217
from ._array_object import Array
218218

219219
if get_array_api_strict_flags()['api_version'] < '2023.12':
220-
if device is not _default:
220+
if device is not _undef:
221221
raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API")
222-
if copy is not _default:
222+
if copy is not _undef:
223223
raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API")
224224

225225
# Going to wait for upstream numpy support
226-
if device is not _default:
226+
if device is not _undef:
227227
_check_device(device)
228228
else:
229229
device = None
230-
if copy not in [_default, None]:
230+
if copy not in [_undef, None]:
231231
raise NotImplementedError("The copy argument to from_dlpack is not yet implemented")
232232

233233
return Array._new(np.from_dlpack(x), device=device)

array_api_strict/_data_type_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from ._array_object import Array, Device
8-
from ._creation_functions import _check_device, _default, _Default
8+
from ._creation_functions import _check_device, _undef, Undef
99
from ._dtypes import (
1010
DType,
1111
_all_dtypes,
@@ -29,9 +29,9 @@ def astype(
2929
*,
3030
copy: bool = True,
3131
# _default is used to emulate the device argument not existing in 2022.12
32-
device: Device | _Default | None = _default,
32+
device: Device | Undef | None = _undef,
3333
) -> Array:
34-
if device is not _default:
34+
if device is not _undef:
3535
if get_array_api_strict_flags()['api_version'] >= '2023.12':
3636
_check_device(device)
3737
else:

array_api_strict/_dtypes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import builtins
34
import warnings
45
from typing import Any
56

@@ -9,8 +10,6 @@
910
# Note: we wrap the NumPy dtype objects in a bare class, so that none of the
1011
# additional methods and behaviors of NumPy dtype objects are exposed.
1112

12-
py_bool = bool
13-
1413

1514
class DType:
1615
_np_dtype: np.dtype[Any]
@@ -22,7 +21,7 @@ def __init__(self, np_dtype: npt.DTypeLike):
2221
def __repr__(self) -> str:
2322
return f"array_api_strict.{self._np_dtype.name}"
2423

25-
def __eq__(self, other: object) -> py_bool:
24+
def __eq__(self, other: object) -> builtins.bool:
2625
# See https://github.com/numpy/numpy/pull/25370/files#r1423259515.
2726
# Avoid the user error of array_api_strict.float32 == numpy.float32,
2827
# which gives False. Making == error is probably too egregious, so

0 commit comments

Comments
 (0)