Skip to content

Commit cad21e9

Browse files
committed
Update the linear algebra functions in the array API namespace
For now, only the functions in from the main spec namespace are implemented. The remaining linear algebra functions are part of an extension in the spec, and will be implemented in a future pull request. This is because the linear algebra functions are relatively complicated, so they will be easier to review separately. This also updates those functions that do remain for now to be more compliant with the spec.
1 parent f6015d2 commit cad21e9

File tree

2 files changed

+31
-165
lines changed

2 files changed

+31
-165
lines changed

numpy/_array_api/__init__.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
4848
- np.argmin and np.argmax do not implement the keepdims keyword argument.
4949
50-
- Some linear algebra functions in the spec are still a work in progress (to
51-
be added soon). These will be updated once the spec is.
50+
- The linear algebra extension in the spec will be added in a future pull
51+
request.
5252
5353
- Some tests in the test suite are still not fully correct in that they test
5454
all datatypes whereas certain functions are only defined for a subset of
@@ -132,13 +132,14 @@
132132

133133
__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
134134

135-
from ._linear_algebra_functions import cross, det, diagonal, inv, norm, outer, trace, transpose
135+
# einsum is not yet implemented in the array API spec.
136136

137-
__all__ += ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose']
137+
# from ._linear_algebra_functions import einsum
138+
# __all__ += ['einsum']
138139

139-
# from ._linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose
140-
#
141-
# __all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
140+
from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot
141+
142+
__all__ += ['matmul', 'tensordot', 'transpose', 'vecdot']
142143

143144
from ._manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
144145

Lines changed: 23 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,16 @@
11
from __future__ import annotations
22

33
from ._array_object import ndarray
4+
from ._dtypes import _numeric_dtypes
45

56
from typing import TYPE_CHECKING
67
if TYPE_CHECKING:
7-
from ._types import Literal, Optional, Tuple, Union, array
8+
from ._types import Optional, Sequence, Tuple, Union, array
89

910
import numpy as np
1011

11-
# def cholesky():
12-
# """
13-
# Array API compatible wrapper for :py:func:`np.cholesky <numpy.cholesky>`.
14-
#
15-
# See its docstring for more information.
16-
# """
17-
# return np.cholesky()
18-
19-
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
20-
"""
21-
Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
12+
# einsum is not yet implemented in the array API spec.
2213

23-
See its docstring for more information.
24-
"""
25-
return ndarray._new(np.cross(x1._array, x2._array, axis=axis))
26-
27-
def det(x: array, /) -> array:
28-
"""
29-
Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
30-
31-
See its docstring for more information.
32-
"""
33-
# Note: this function is being imported from a nondefault namespace
34-
return ndarray._new(np.linalg.det(x._array))
35-
36-
def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
37-
"""
38-
Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
39-
40-
See its docstring for more information.
41-
"""
42-
return ndarray._new(np.diagonal(x._array, axis1=axis1, axis2=axis2, offset=offset))
43-
44-
# def dot():
45-
# """
46-
# Array API compatible wrapper for :py:func:`np.dot <numpy.dot>`.
47-
#
48-
# See its docstring for more information.
49-
# """
50-
# return np.dot()
51-
#
52-
# def eig():
53-
# """
54-
# Array API compatible wrapper for :py:func:`np.eig <numpy.eig>`.
55-
#
56-
# See its docstring for more information.
57-
# """
58-
# return np.eig()
59-
#
60-
# def eigvalsh():
61-
# """
62-
# Array API compatible wrapper for :py:func:`np.eigvalsh <numpy.eigvalsh>`.
63-
#
64-
# See its docstring for more information.
65-
# """
66-
# return np.eigvalsh()
67-
#
6814
# def einsum():
6915
# """
7016
# Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`.
@@ -73,114 +19,27 @@ def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) ->
7319
# """
7420
# return np.einsum()
7521

76-
def inv(x: array, /) -> array:
77-
"""
78-
Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
79-
80-
See its docstring for more information.
81-
"""
82-
# Note: this function is being imported from a nondefault namespace
83-
return ndarray._new(np.linalg.inv(x._array))
84-
85-
# def lstsq():
86-
# """
87-
# Array API compatible wrapper for :py:func:`np.lstsq <numpy.lstsq>`.
88-
#
89-
# See its docstring for more information.
90-
# """
91-
# return np.lstsq()
92-
#
93-
# def matmul():
94-
# """
95-
# Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
96-
#
97-
# See its docstring for more information.
98-
# """
99-
# return np.matmul()
100-
#
101-
# def matrix_power():
102-
# """
103-
# Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`.
104-
#
105-
# See its docstring for more information.
106-
# """
107-
# return np.matrix_power()
108-
#
109-
# def matrix_rank():
110-
# """
111-
# Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
112-
#
113-
# See its docstring for more information.
114-
# """
115-
# return np.matrix_rank()
116-
117-
def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[np.inf, -np.inf, 'fro', 'nuc']]] = None) -> array:
118-
"""
119-
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
120-
121-
See its docstring for more information.
122-
"""
123-
# Note: this is different from the default behavior
124-
if axis == None and x.ndim > 2:
125-
x = ndarray._new(x._array.flatten())
126-
# Note: this function is being imported from a nondefault namespace
127-
return ndarray._new(np.linalg.norm(x._array, axis=axis, keepdims=keepdims, ord=ord))
128-
129-
def outer(x1: array, x2: array, /) -> array:
22+
def matmul(x1: array, x2: array, /) -> array:
13023
"""
131-
Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
24+
Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
13225
13326
See its docstring for more information.
13427
"""
135-
return ndarray._new(np.outer(x1._array, x2._array))
28+
# Note: the restriction to numeric dtypes only is different from
29+
# np.matmul.
30+
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
31+
raise TypeError('Only numeric dtypes are allowed in matmul')
13632

137-
# def pinv():
138-
# """
139-
# Array API compatible wrapper for :py:func:`np.pinv <numpy.pinv>`.
140-
#
141-
# See its docstring for more information.
142-
# """
143-
# return np.pinv()
144-
#
145-
# def qr():
146-
# """
147-
# Array API compatible wrapper for :py:func:`np.qr <numpy.qr>`.
148-
#
149-
# See its docstring for more information.
150-
# """
151-
# return np.qr()
152-
#
153-
# def slogdet():
154-
# """
155-
# Array API compatible wrapper for :py:func:`np.slogdet <numpy.slogdet>`.
156-
#
157-
# See its docstring for more information.
158-
# """
159-
# return np.slogdet()
160-
#
161-
# def solve():
162-
# """
163-
# Array API compatible wrapper for :py:func:`np.solve <numpy.solve>`.
164-
#
165-
# See its docstring for more information.
166-
# """
167-
# return np.solve()
168-
#
169-
# def svd():
170-
# """
171-
# Array API compatible wrapper for :py:func:`np.svd <numpy.svd>`.
172-
#
173-
# See its docstring for more information.
174-
# """
175-
# return np.svd()
33+
return ndarray._new(np.matmul(x1._array, x2._array))
17634

177-
def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
178-
"""
179-
Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
35+
# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
36+
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
37+
# Note: the restriction to numeric dtypes only is different from
38+
# np.tensordot.
39+
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
40+
raise TypeError('Only numeric dtypes are allowed in tensordot')
18041

181-
See its docstring for more information.
182-
"""
183-
return ndarray._new(np.asarray(np.trace(x._array, axis1=axis1, axis2=axis2, offset=offset)))
42+
return ndarray._new(np.tensordot(x1._array, x2._array, axes=axes))
18443

18544
def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
18645
"""
@@ -189,3 +48,9 @@ def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
18948
See its docstring for more information.
19049
"""
19150
return ndarray._new(np.transpose(x._array, axes=axes))
51+
52+
# Note: vecdot is not in NumPy
53+
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
54+
if axis is None:
55+
axis = -1
56+
return tensordot(x1, x2, axes=((axis,), (axis,)))

0 commit comments

Comments
 (0)