1
1
from __future__ import annotations
2
2
3
3
from ._array_object import ndarray
4
+ from ._dtypes import _numeric_dtypes
4
5
5
6
from typing import TYPE_CHECKING
6
7
if TYPE_CHECKING :
7
- from ._types import Literal , Optional , Tuple , Union , array
8
+ from ._types import Optional , Sequence , Tuple , Union , array
8
9
9
10
import numpy as np
10
11
11
- # def cholesky():
12
- # """
13
- # Array API compatible wrapper for :py:func:`np.cholesky <numpy.cholesky>`.
14
- #
15
- # See its docstring for more information.
16
- # """
17
- # return np.cholesky()
18
-
19
- def cross (x1 : array , x2 : array , / , * , axis : int = - 1 ) -> array :
20
- """
21
- Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
12
+ # einsum is not yet implemented in the array API spec.
22
13
23
- See its docstring for more information.
24
- """
25
- return ndarray ._new (np .cross (x1 ._array , x2 ._array , axis = axis ))
26
-
27
- def det (x : array , / ) -> array :
28
- """
29
- Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
30
-
31
- See its docstring for more information.
32
- """
33
- # Note: this function is being imported from a nondefault namespace
34
- return ndarray ._new (np .linalg .det (x ._array ))
35
-
36
- def diagonal (x : array , / , * , axis1 : int = 0 , axis2 : int = 1 , offset : int = 0 ) -> array :
37
- """
38
- Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
39
-
40
- See its docstring for more information.
41
- """
42
- return ndarray ._new (np .diagonal (x ._array , axis1 = axis1 , axis2 = axis2 , offset = offset ))
43
-
44
- # def dot():
45
- # """
46
- # Array API compatible wrapper for :py:func:`np.dot <numpy.dot>`.
47
- #
48
- # See its docstring for more information.
49
- # """
50
- # return np.dot()
51
- #
52
- # def eig():
53
- # """
54
- # Array API compatible wrapper for :py:func:`np.eig <numpy.eig>`.
55
- #
56
- # See its docstring for more information.
57
- # """
58
- # return np.eig()
59
- #
60
- # def eigvalsh():
61
- # """
62
- # Array API compatible wrapper for :py:func:`np.eigvalsh <numpy.eigvalsh>`.
63
- #
64
- # See its docstring for more information.
65
- # """
66
- # return np.eigvalsh()
67
- #
68
14
# def einsum():
69
15
# """
70
16
# Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`.
@@ -73,114 +19,27 @@ def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) ->
73
19
# """
74
20
# return np.einsum()
75
21
76
- def inv (x : array , / ) -> array :
77
- """
78
- Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
79
-
80
- See its docstring for more information.
81
- """
82
- # Note: this function is being imported from a nondefault namespace
83
- return ndarray ._new (np .linalg .inv (x ._array ))
84
-
85
- # def lstsq():
86
- # """
87
- # Array API compatible wrapper for :py:func:`np.lstsq <numpy.lstsq>`.
88
- #
89
- # See its docstring for more information.
90
- # """
91
- # return np.lstsq()
92
- #
93
- # def matmul():
94
- # """
95
- # Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
96
- #
97
- # See its docstring for more information.
98
- # """
99
- # return np.matmul()
100
- #
101
- # def matrix_power():
102
- # """
103
- # Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`.
104
- #
105
- # See its docstring for more information.
106
- # """
107
- # return np.matrix_power()
108
- #
109
- # def matrix_rank():
110
- # """
111
- # Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
112
- #
113
- # See its docstring for more information.
114
- # """
115
- # return np.matrix_rank()
116
-
117
- def norm (x : array , / , * , axis : Optional [Union [int , Tuple [int , int ]]] = None , keepdims : bool = False , ord : Optional [Union [int , float , Literal [np .inf , - np .inf , 'fro' , 'nuc' ]]] = None ) -> array :
118
- """
119
- Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
120
-
121
- See its docstring for more information.
122
- """
123
- # Note: this is different from the default behavior
124
- if axis == None and x .ndim > 2 :
125
- x = ndarray ._new (x ._array .flatten ())
126
- # Note: this function is being imported from a nondefault namespace
127
- return ndarray ._new (np .linalg .norm (x ._array , axis = axis , keepdims = keepdims , ord = ord ))
128
-
129
- def outer (x1 : array , x2 : array , / ) -> array :
22
+ def matmul (x1 : array , x2 : array , / ) -> array :
130
23
"""
131
- Array API compatible wrapper for :py:func:`np.outer <numpy.outer >`.
24
+ Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul >`.
132
25
133
26
See its docstring for more information.
134
27
"""
135
- return ndarray ._new (np .outer (x1 ._array , x2 ._array ))
28
+ # Note: the restriction to numeric dtypes only is different from
29
+ # np.matmul.
30
+ if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
31
+ raise TypeError ('Only numeric dtypes are allowed in matmul' )
136
32
137
- # def pinv():
138
- # """
139
- # Array API compatible wrapper for :py:func:`np.pinv <numpy.pinv>`.
140
- #
141
- # See its docstring for more information.
142
- # """
143
- # return np.pinv()
144
- #
145
- # def qr():
146
- # """
147
- # Array API compatible wrapper for :py:func:`np.qr <numpy.qr>`.
148
- #
149
- # See its docstring for more information.
150
- # """
151
- # return np.qr()
152
- #
153
- # def slogdet():
154
- # """
155
- # Array API compatible wrapper for :py:func:`np.slogdet <numpy.slogdet>`.
156
- #
157
- # See its docstring for more information.
158
- # """
159
- # return np.slogdet()
160
- #
161
- # def solve():
162
- # """
163
- # Array API compatible wrapper for :py:func:`np.solve <numpy.solve>`.
164
- #
165
- # See its docstring for more information.
166
- # """
167
- # return np.solve()
168
- #
169
- # def svd():
170
- # """
171
- # Array API compatible wrapper for :py:func:`np.svd <numpy.svd>`.
172
- #
173
- # See its docstring for more information.
174
- # """
175
- # return np.svd()
33
+ return ndarray ._new (np .matmul (x1 ._array , x2 ._array ))
176
34
177
- def trace (x : array , / , * , axis1 : int = 0 , axis2 : int = 1 , offset : int = 0 ) -> array :
178
- """
179
- Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
35
+ # Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
36
+ def tensordot (x1 : array , x2 : array , / , * , axes : Union [int , Tuple [Sequence [int ], Sequence [int ]]] = 2 ) -> array :
37
+ # Note: the restriction to numeric dtypes only is different from
38
+ # np.tensordot.
39
+ if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
40
+ raise TypeError ('Only numeric dtypes are allowed in tensordot' )
180
41
181
- See its docstring for more information.
182
- """
183
- return ndarray ._new (np .asarray (np .trace (x ._array , axis1 = axis1 , axis2 = axis2 , offset = offset )))
42
+ return ndarray ._new (np .tensordot (x1 ._array , x2 ._array , axes = axes ))
184
43
185
44
def transpose (x : array , / , * , axes : Optional [Tuple [int , ...]] = None ) -> array :
186
45
"""
@@ -189,3 +48,9 @@ def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
189
48
See its docstring for more information.
190
49
"""
191
50
return ndarray ._new (np .transpose (x ._array , axes = axes ))
51
+
52
+ # Note: vecdot is not in NumPy
53
+ def vecdot (x1 : array , x2 : array , / , * , axis : Optional [int ] = None ) -> array :
54
+ if axis is None :
55
+ axis = - 1
56
+ return tensordot (x1 , x2 , axes = ((axis ,), (axis ,)))
0 commit comments