Skip to content

Commit 4378e1a

Browse files
committed
Update stubs for the latest version of the spec
1 parent 1063bb6 commit 4378e1a

File tree

6 files changed

+24
-21
lines changed

6 files changed

+24
-21
lines changed

array_api_tests/function_stubs/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
__all__ += ['e', 'inf', 'nan', 'pi']
1919

20-
from .creation_functions import arange, asarray, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like
20+
from .creation_functions import arange, asarray, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, tril, triu, zeros, zeros_like
2121

22-
__all__ += ['arange', 'asarray', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'zeros', 'zeros_like']
22+
__all__ += ['arange', 'asarray', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'tril', 'triu', 'zeros', 'zeros_like']
2323

2424
from .data_type_functions import broadcast_arrays, broadcast_to, can_cast, finfo, iinfo, result_type
2525

@@ -29,13 +29,13 @@
2929

3030
__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
3131

32-
from .linear_algebra_functions import einsum, matmul, matrix_transpose, tensordot, vecdot
32+
from .linear_algebra_functions import matmul, matrix_transpose, tensordot, vecdot
3333

34-
__all__ += ['einsum', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot']
34+
__all__ += ['matmul', 'matrix_transpose', 'tensordot', 'vecdot']
3535

36-
from .manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
36+
from .manipulation_functions import concat, expand_dims, flip, permute_dims, reshape, roll, squeeze, stack
3737

38-
__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack']
38+
__all__ += ['concat', 'expand_dims', 'flip', 'permute_dims', 'reshape', 'roll', 'squeeze', 'stack']
3939

4040
from .searching_functions import argmax, argmin, nonzero, where
4141

array_api_tests/function_stubs/creation_functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,16 @@ def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None, d
5050
def ones_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
5151
pass
5252

53+
def tril(x: array, /, *, k: int = 0) -> array:
54+
pass
55+
56+
def triu(x: array, /, *, k: int = 0) -> array:
57+
pass
58+
5359
def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
5460
pass
5561

5662
def zeros_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
5763
pass
5864

59-
__all__ = ['arange', 'asarray', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'zeros', 'zeros_like']
65+
__all__ = ['arange', 'asarray', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'tril', 'triu', 'zeros', 'zeros_like']

array_api_tests/function_stubs/linalg.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,13 @@ def diagonal(x: array, /, *, offset: int = 0) -> array:
2929
def eig():
3030
pass
3131

32-
def eigh(x: array, /, *, upper: bool = False) -> Tuple[array]:
32+
def eigh(x: array, /) -> Tuple[array]:
3333
pass
3434

3535
def eigvals():
3636
pass
3737

38-
def eigvalsh(x: array, /, *, upper: bool = False) -> array:
39-
pass
40-
41-
def einsum():
38+
def eigvalsh(x: array, /) -> array:
4239
pass
4340

4441
def inv(x: array, /) -> array:
@@ -92,4 +89,4 @@ def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
9289
def vector_norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[inf, -inf]]] = 2) -> array:
9390
pass
9491

95-
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'einsum', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']
92+
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']

array_api_tests/function_stubs/linear_algebra_functions.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
from ._types import Optional, Tuple, Union, array
1414
from collections.abc import Sequence
1515

16-
def einsum():
17-
pass
18-
1916
def matmul(x1: array, x2: array, /) -> array:
2017
pass
2118

@@ -28,4 +25,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
2825
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
2926
pass
3027

31-
__all__ = ['einsum', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot']
28+
__all__ = ['matmul', 'matrix_transpose', 'tensordot', 'vecdot']

array_api_tests/function_stubs/manipulation_functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def expand_dims(x: array, /, *, axis: int) -> array:
2121
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
2222
pass
2323

24+
def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
25+
pass
26+
2427
def reshape(x: array, /, shape: Tuple[int, ...]) -> array:
2528
pass
2629

@@ -33,4 +36,4 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
3336
def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> array:
3437
pass
3538

36-
__all__ = ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack']
39+
__all__ = ['concat', 'expand_dims', 'flip', 'permute_dims', 'reshape', 'roll', 'squeeze', 'stack']

array_api_tests/function_stubs/statistical_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
from ._types import Optional, Tuple, Union, array
13+
from ._types import Optional, Tuple, Union, array, dtype
1414

1515
def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
1616
pass
@@ -21,13 +21,13 @@ def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, kee
2121
def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
2222
pass
2323

24-
def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
24+
def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array:
2525
pass
2626

2727
def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array:
2828
pass
2929

30-
def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
30+
def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array:
3131
pass
3232

3333
def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array:

0 commit comments

Comments
 (0)