Skip to content

Commit 74478e2

Browse files
committed
Use better type signatures in the array API module
This includes returning custom dataclasses for finfo and iinfo that only contain the properties required by the array API specification.
1 parent 29b7a69 commit 74478e2

File tree

4 files changed

+46
-12
lines changed

4 files changed

+46
-12
lines changed

numpy/_array_api/_array_object.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,8 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
396396
res = self._array.__le__(other._array)
397397
return self.__class__._new(res)
398398

399-
def __len__(self, /):
399+
# Note: __len__ may end up being removed from the array API spec.
400+
def __len__(self, /) -> int:
400401
"""
401402
Performs the operation __len__.
402403
"""
@@ -843,7 +844,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
843844
return self.__class__._new(res)
844845

845846
@property
846-
def dtype(self):
847+
def dtype(self) -> Dtype:
847848
"""
848849
Array API compatible wrapper for :py:meth:`np.ndaray.dtype <numpy.ndarray.dtype>`.
849850
@@ -852,7 +853,7 @@ def dtype(self):
852853
return self._array.dtype
853854

854855
@property
855-
def device(self):
856+
def device(self) -> Device:
856857
"""
857858
Array API compatible wrapper for :py:meth:`np.ndaray.device <numpy.ndarray.device>`.
858859
@@ -862,7 +863,7 @@ def device(self):
862863
raise NotImplementedError("The device attribute is not yet implemented")
863864

864865
@property
865-
def ndim(self):
866+
def ndim(self) -> int:
866867
"""
867868
Array API compatible wrapper for :py:meth:`np.ndaray.ndim <numpy.ndarray.ndim>`.
868869
@@ -871,7 +872,7 @@ def ndim(self):
871872
return self._array.ndim
872873

873874
@property
874-
def shape(self):
875+
def shape(self) -> Tuple[int, ...]:
875876
"""
876877
Array API compatible wrapper for :py:meth:`np.ndaray.shape <numpy.ndarray.shape>`.
877878
@@ -880,7 +881,7 @@ def shape(self):
880881
return self._array.shape
881882

882883
@property
883-
def size(self):
884+
def size(self) -> int:
884885
"""
885886
Array API compatible wrapper for :py:meth:`np.ndaray.size <numpy.ndarray.size>`.
886887
@@ -889,7 +890,7 @@ def size(self):
889890
return self._array.size
890891

891892
@property
892-
def T(self):
893+
def T(self) -> Array:
893894
"""
894895
Array API compatible wrapper for :py:meth:`np.ndaray.T <numpy.ndarray.T>`.
895896

numpy/_array_api/_creation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import numpy as np
1212

13-
def asarray(obj: Union[float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array:
13+
def asarray(obj: Union[Array, float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array:
1414
"""
1515
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
1616

numpy/_array_api/_data_type_functions.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._array_object import Array
44

5+
from dataclasses import dataclass
56
from typing import TYPE_CHECKING
67
if TYPE_CHECKING:
78
from ._types import List, Tuple, Union, Dtype
@@ -38,21 +39,53 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
3839
from_ = from_._array
3940
return np.can_cast(from_, to)
4041

42+
# These are internal objects for the return types of finfo and iinfo, since
43+
# the NumPy versions contain extra data that isn't part of the spec.
44+
@dataclass
45+
class finfo_object:
46+
bits: int
47+
# Note: The types of the float data here are float, whereas in NumPy they
48+
# are scalars of the corresponding float dtype.
49+
eps: float
50+
max: float
51+
min: float
52+
# Note: smallest_normal is part of the array API spec, but cannot be used
53+
# until https://github.com/numpy/numpy/pull/18536 is merged.
54+
55+
# smallest_normal: float
56+
57+
@dataclass
58+
class iinfo_object:
59+
bits: int
60+
max: int
61+
min: int
62+
4163
def finfo(type: Union[Dtype, Array], /) -> finfo_object:
4264
"""
4365
Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`.
4466
4567
See its docstring for more information.
4668
"""
47-
return np.finfo(type)
69+
fi = np.finfo(type)
70+
# Note: The types of the float data here are float, whereas in NumPy they
71+
# are scalars of the corresponding float dtype.
72+
return finfo_object(
73+
fi.bits,
74+
float(fi.eps),
75+
float(fi.max),
76+
float(fi.min),
77+
# TODO: Uncomment this when #18536 is merged.
78+
# float(fi.smallest_normal),
79+
)
4880

4981
def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
5082
"""
5183
Array API compatible wrapper for :py:func:`np.iinfo <numpy.iinfo>`.
5284
5385
See its docstring for more information.
5486
"""
55-
return np.iinfo(type)
87+
ii = np.iinfo(type)
88+
return iinfo_object(ii.bits, ii.max, ii.min)
5689

5790
def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype:
5891
"""

numpy/_array_api/_manipulation_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
# Note: the function name is different here
10-
def concat(arrays: Tuple[Array, ...], /, *, axis: Optional[int] = 0) -> Array:
10+
def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
1111
"""
1212
Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
1313
@@ -56,7 +56,7 @@ def squeeze(x: Array, /, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
5656
"""
5757
return Array._new(np.squeeze(x._array, axis=axis))
5858

59-
def stack(arrays: Tuple[Array, ...], /, *, axis: int = 0) -> Array:
59+
def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
6060
"""
6161
Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
6262

0 commit comments

Comments
 (0)