Skip to content

Commit 8926ce7

Browse files
committed
Update the function stubs from the latest version of the spec
TODO: I also need to make corresponding update to the NumPy implementation.
1 parent c1dba80 commit 8926ce7

File tree

6 files changed

+29
-16
lines changed

6 files changed

+29
-16
lines changed

array_api_tests/function_stubs/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
__all__ = []
1111

12-
from .array_object import __abs__, __add__, __and__, __array_namespace__, __bool__, __dlpack__, __dlpack_device__, __eq__, __float__, __floordiv__, __ge__, __getitem__, __gt__, __int__, __invert__, __le__, __len__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __rshift__, __setitem__, __sub__, __truediv__, __xor__, __iadd__, __radd__, __iand__, __rand__, __ifloordiv__, __rfloordiv__, __ilshift__, __rlshift__, __imatmul__, __rmatmul__, __imod__, __rmod__, __imul__, __rmul__, __ior__, __ror__, __ipow__, __rpow__, __irshift__, __rrshift__, __isub__, __rsub__, __itruediv__, __rtruediv__, __ixor__, __rxor__, dtype, device, ndim, shape, size, T
12+
from .array_object import __abs__, __add__, __and__, __array_namespace__, __bool__, __dlpack__, __dlpack_device__, __eq__, __float__, __floordiv__, __ge__, __getitem__, __gt__, __index__, __int__, __invert__, __le__, __len__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __rshift__, __setitem__, __sub__, __truediv__, __xor__, to_device, __iadd__, __radd__, __iand__, __rand__, __ifloordiv__, __rfloordiv__, __ilshift__, __rlshift__, __imatmul__, __rmatmul__, __imod__, __rmod__, __imul__, __rmul__, __ior__, __ror__, __ipow__, __rpow__, __irshift__, __rrshift__, __isub__, __rsub__, __itruediv__, __rtruediv__, __ixor__, __rxor__, dtype, device, mT, ndim, shape, size, T
1313

14-
__all__ += ['__abs__', '__add__', '__and__', '__array_namespace__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__int__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', '__iadd__', '__radd__', '__iand__', '__rand__', '__ifloordiv__', '__rfloordiv__', '__ilshift__', '__rlshift__', '__imatmul__', '__rmatmul__', '__imod__', '__rmod__', '__imul__', '__rmul__', '__ior__', '__ror__', '__ipow__', '__rpow__', '__irshift__', '__rrshift__', '__isub__', '__rsub__', '__itruediv__', '__rtruediv__', '__ixor__', '__rxor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']
14+
__all__ += ['__abs__', '__add__', '__and__', '__array_namespace__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__index__', '__int__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'to_device', '__iadd__', '__radd__', '__iand__', '__rand__', '__ifloordiv__', '__rfloordiv__', '__ilshift__', '__rlshift__', '__imatmul__', '__rmatmul__', '__imod__', '__rmod__', '__imul__', '__rmul__', '__ior__', '__ror__', '__ipow__', '__rpow__', '__irshift__', '__rrshift__', '__isub__', '__rsub__', '__itruediv__', '__rtruediv__', '__ixor__', '__rxor__', 'dtype', 'device', 'mT', 'ndim', 'shape', 'size', 'T']
1515

1616
from .constants import e, inf, nan, pi
1717

@@ -29,9 +29,9 @@
2929

3030
__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
3131

32-
from .linear_algebra_functions import einsum, matmul, tensordot, transpose, vecdot
32+
from .linear_algebra_functions import einsum, matmul, matrix_transpose, tensordot, vecdot
3333

34-
__all__ += ['einsum', 'matmul', 'tensordot', 'transpose', 'vecdot']
34+
__all__ += ['einsum', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot']
3535

3636
from .manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
3737

array_api_tests/function_stubs/array_object.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ def __gt__(self: array, other: Union[int, float, array], /) -> array:
9191
"""
9292
pass
9393

94+
def __index__(self: array, /) -> int:
95+
"""
96+
Note: __index__ is a method of the array object.
97+
"""
98+
pass
99+
94100
def __int__(self: array, /) -> int:
95101
"""
96102
Note: __int__ is a method of the array object.
@@ -205,6 +211,12 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
205211
"""
206212
pass
207213

214+
def to_device(self: array, device: device, /) -> array:
215+
"""
216+
Note: to_device is a method of the array object.
217+
"""
218+
pass
219+
208220
def __iadd__(self: array, other: Union[int, float, array], /) -> array:
209221
"""
210222
Note: __iadd__ is a method of the array object.
@@ -367,6 +379,9 @@ def __rxor__(self: array, other: Union[int, bool, array], /) -> array:
367379
# Note: device is an attribute of the array object.
368380
device: device = None
369381

382+
# Note: mT is an attribute of the array object.
383+
mT: array = None
384+
370385
# Note: ndim is an attribute of the array object.
371386
ndim: int = None
372387

@@ -379,4 +394,4 @@ def __rxor__(self: array, other: Union[int, bool, array], /) -> array:
379394
# Note: T is an attribute of the array object.
380395
T: array = None
381396

382-
__all__ = ['__abs__', '__add__', '__and__', '__array_namespace__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__int__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', '__iadd__', '__radd__', '__iand__', '__rand__', '__ifloordiv__', '__rfloordiv__', '__ilshift__', '__rlshift__', '__imatmul__', '__rmatmul__', '__imod__', '__rmod__', '__imul__', '__rmul__', '__ior__', '__ror__', '__ipow__', '__rpow__', '__irshift__', '__rrshift__', '__isub__', '__rsub__', '__itruediv__', '__rtruediv__', '__ixor__', '__rxor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']
397+
__all__ = ['__abs__', '__add__', '__and__', '__array_namespace__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__index__', '__int__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'to_device', '__iadd__', '__radd__', '__iand__', '__rand__', '__ifloordiv__', '__rfloordiv__', '__ilshift__', '__rlshift__', '__imatmul__', '__rmatmul__', '__imod__', '__rmod__', '__imul__', '__rmul__', '__ior__', '__ror__', '__ipow__', '__rpow__', '__irshift__', '__rrshift__', '__isub__', '__rsub__', '__itruediv__', '__rtruediv__', '__ixor__', '__rxor__', 'dtype', 'device', 'mT', 'ndim', 'shape', 'size', 'T']

array_api_tests/function_stubs/linalg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def matrix_power(x: array, n: int, /) -> array:
5555
def matrix_rank(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
5656
pass
5757

58+
def matrix_transpose(x, /):
59+
pass
60+
5861
def outer(x1: array, x2: array, /) -> array:
5962
pass
6063

@@ -82,13 +85,10 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
8285
def trace(x: array, /, *, offset: int = 0) -> array:
8386
pass
8487

85-
def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
86-
pass
87-
8888
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
8989
pass
9090

9191
def vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
9292
pass
9393

94-
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'einsum', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'transpose', 'vecdot', 'vector_norm']
94+
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'einsum', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']

array_api_tests/function_stubs/linear_algebra_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ def einsum():
1919
def matmul(x1: array, x2: array, /) -> array:
2020
pass
2121

22-
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
22+
def matrix_transpose(x, /):
2323
pass
2424

25-
def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
25+
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
2626
pass
2727

2828
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
2929
pass
3030

31-
__all__ = ['einsum', 'matmul', 'tensordot', 'transpose', 'vecdot']
31+
__all__ = ['einsum', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot']

array_api_tests/special_cases/test_atanh.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
not modify it directly.
88
"""
99

10-
from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, less, one,
11-
zero)
10+
from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, greater, infinity, less, one, zero
1211
from ..hypothesis_helpers import numeric_arrays
1312
from .._array_module import atanh
1413

array_api_tests/special_cases/test_sign.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
not modify it directly.
88
"""
99

10-
from ..array_helpers import (assert_exactly_equal, exactly_equal, greater, less, logical_or, one,
11-
zero)
10+
from ..array_helpers import assert_exactly_equal, exactly_equal, greater, less, logical_or, one, zero
1211
from ..hypothesis_helpers import numeric_arrays
1312
from .._array_module import sign
1413

0 commit comments

Comments
 (0)