5
5
import cupy as cp
6
6
7
7
from ..common import _aliases
8
+ from ..common import _linalg
8
9
9
10
from .._internal import get_xp
10
11
73
74
else :
74
75
isdtype = get_xp (cp )(_aliases .isdtype )
75
76
77
+
78
+ cross = get_xp (cp )(_linalg .cross )
79
+ outer = get_xp (cp )(_linalg .outer )
80
+ EighResult = _linalg .EighResult
81
+ QRResult = _linalg .QRResult
82
+ SlogdetResult = _linalg .SlogdetResult
83
+ SVDResult = _linalg .SVDResult
84
+ eigh = get_xp (cp )(_linalg .eigh )
85
+ qr = get_xp (cp )(_linalg .qr )
86
+ slogdet = get_xp (cp )(_linalg .slogdet )
87
+ svd = get_xp (cp )(_linalg .svd )
88
+ cholesky = get_xp (cp )(_linalg .cholesky )
89
+ matrix_rank = get_xp (cp )(_linalg .matrix_rank )
90
+ pinv = get_xp (cp )(_linalg .pinv )
91
+ matrix_norm = get_xp (cp )(_linalg .matrix_norm )
92
+ svdvals = get_xp (cp )(_linalg .svdvals )
93
+ diagonal = get_xp (cp )(_linalg .diagonal )
94
+ trace = get_xp (cp )(_linalg .trace )
95
+
96
+ # These functions are completely new here. If the library already has them
97
+ # (i.e., numpy 2.0), use the library version instead of our wrapper.
98
+ if hasattr (cp .linalg , 'vector_norm' ):
99
+ vector_norm = cp .linalg .vector_norm
100
+ else :
101
+ vector_norm = get_xp (cp )(_linalg .vector_norm )
0 commit comments