Skip to content

Commit 9e9a82c

Browse files
committed
Move linalg aliases to _aliases
1 parent f0bb5f8 commit 9e9a82c

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import cupy as cp
66

77
from ..common import _aliases
8+
from ..common import _linalg
89

910
from .._internal import get_xp
1011

@@ -73,3 +74,28 @@
7374
else:
7475
isdtype = get_xp(cp)(_aliases.isdtype)
7576

77+
78+
cross = get_xp(cp)(_linalg.cross)
79+
outer = get_xp(cp)(_linalg.outer)
80+
EighResult = _linalg.EighResult
81+
QRResult = _linalg.QRResult
82+
SlogdetResult = _linalg.SlogdetResult
83+
SVDResult = _linalg.SVDResult
84+
eigh = get_xp(cp)(_linalg.eigh)
85+
qr = get_xp(cp)(_linalg.qr)
86+
slogdet = get_xp(cp)(_linalg.slogdet)
87+
svd = get_xp(cp)(_linalg.svd)
88+
cholesky = get_xp(cp)(_linalg.cholesky)
89+
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
90+
pinv = get_xp(cp)(_linalg.pinv)
91+
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
92+
svdvals = get_xp(cp)(_linalg.svdvals)
93+
diagonal = get_xp(cp)(_linalg.diagonal)
94+
trace = get_xp(cp)(_linalg.trace)
95+
96+
# These functions are completely new here. If the library already has them
97+
# (i.e., numpy 2.0), use the library version instead of our wrapper.
98+
if hasattr(cp.linalg, 'vector_norm'):
99+
vector_norm = cp.linalg.vector_norm
100+
else:
101+
vector_norm = get_xp(cp)(_linalg.vector_norm)

0 commit comments

Comments
 (0)