Skip to content

Commit 7dc025e

Browse files
committed
Move linalg aliases to numpy/_aliases
1 parent 91032fa commit 7dc025e

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

array_api_compat/numpy/_aliases.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
import numpy as np
66

7-
from ..common import _aliases
8-
97
from .._internal import get_xp
8+
from ..common import _aliases
9+
from ..common import _linalg
1010

11-
asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy')
11+
asarray = asarray_numpy = partial(_aliases._asarray, namespace="numpy")
1212
asarray.__doc__ = _aliases._asarray.__doc__
1313

1414
bool = np.bool_
@@ -64,11 +64,37 @@
6464

6565
# These functions are completely new here. If the library already has them
6666
# (i.e., numpy 2.0), use the library version instead of our wrapper.
67-
if hasattr(np, 'vecdot'):
67+
if hasattr(np, "vecdot"):
6868
vecdot = np.vecdot
6969
else:
7070
vecdot = get_xp(np)(_aliases.vecdot)
71-
if hasattr(np, 'isdtype'):
71+
if hasattr(np, "isdtype"):
7272
isdtype = np.isdtype
7373
else:
7474
isdtype = get_xp(np)(_aliases.isdtype)
75+
76+
77+
cross = get_xp(np)(_linalg.cross)
78+
outer = get_xp(np)(_linalg.outer)
79+
EighResult = _linalg.EighResult
80+
QRResult = _linalg.QRResult
81+
SlogdetResult = _linalg.SlogdetResult
82+
SVDResult = _linalg.SVDResult
83+
eigh = get_xp(np)(_linalg.eigh)
84+
qr = get_xp(np)(_linalg.qr)
85+
slogdet = get_xp(np)(_linalg.slogdet)
86+
svd = get_xp(np)(_linalg.svd)
87+
cholesky = get_xp(np)(_linalg.cholesky)
88+
matrix_rank = get_xp(np)(_linalg.matrix_rank)
89+
pinv = get_xp(np)(_linalg.pinv)
90+
matrix_norm = get_xp(np)(_linalg.matrix_norm)
91+
svdvals = get_xp(np)(_linalg.svdvals)
92+
diagonal = get_xp(np)(_linalg.diagonal)
93+
trace = get_xp(np)(_linalg.trace)
94+
95+
# These functions are completely new here. If the library already has them
96+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
97+
if hasattr(np.linalg, "vector_norm"):
98+
vector_norm = np.linalg.vector_norm
99+
else:
100+
vector_norm = get_xp(np)(_linalg.vector_norm)

0 commit comments

Comments
 (0)