Skip to content

Commit 91032fa

Browse files
committed
Fix ruff errors in cupy/linalg
1 parent 9e9a82c commit 91032fa

File tree

1 file changed

+55
-41
lines changed

1 file changed

+55
-41
lines changed

array_api_compat/cupy/linalg.py

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,61 @@
1-
from cupy.linalg import *
2-
# cupy.linalg doesn't have __all__. If it is added, replace this with
3-
#
4-
# from cupy.linalg import __all__ as linalg_all
5-
_n = {}
6-
exec('from cupy.linalg import *', _n)
7-
del _n['__builtins__']
8-
linalg_all = list(_n)
9-
del _n
1+
import cupy as cp
2+
from .._internal import _get_all_public_members
103

11-
from ..common import _linalg
12-
from .._internal import get_xp
13-
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)
4+
_cupy_linalg_all = _get_all_public_members(cp.linalg)
145

15-
import cupy as cp
6+
for name in _cupy_linalg_all:
7+
globals()[name] = getattr(cp.linalg, name)
168

17-
cross = get_xp(cp)(_linalg.cross)
18-
outer = get_xp(cp)(_linalg.outer)
19-
EighResult = _linalg.EighResult
20-
QRResult = _linalg.QRResult
21-
SlogdetResult = _linalg.SlogdetResult
22-
SVDResult = _linalg.SVDResult
23-
eigh = get_xp(cp)(_linalg.eigh)
24-
qr = get_xp(cp)(_linalg.qr)
25-
slogdet = get_xp(cp)(_linalg.slogdet)
26-
svd = get_xp(cp)(_linalg.svd)
27-
cholesky = get_xp(cp)(_linalg.cholesky)
28-
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
29-
pinv = get_xp(cp)(_linalg.pinv)
30-
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
31-
svdvals = get_xp(cp)(_linalg.svdvals)
32-
diagonal = get_xp(cp)(_linalg.diagonal)
33-
trace = get_xp(cp)(_linalg.trace)
9+
from ._aliases import ( # noqa: E402
10+
EighResult,
11+
QRResult,
12+
SlogdetResult,
13+
SVDResult,
14+
cholesky,
15+
cross,
16+
diagonal,
17+
eigh,
18+
matmul,
19+
matrix_norm,
20+
matrix_rank,
21+
matrix_transpose,
22+
outer,
23+
pinv,
24+
qr,
25+
slogdet,
26+
svd,
27+
svdvals,
28+
tensordot,
29+
trace,
30+
vecdot,
31+
vector_norm,
32+
)
3433

35-
# These functions are completely new here. If the library already has them
36-
# (i.e., numpy 2.0), use the library version instead of our wrapper.
37-
if hasattr(cp.linalg, 'vector_norm'):
38-
vector_norm = cp.linalg.vector_norm
39-
else:
40-
vector_norm = get_xp(cp)(_linalg.vector_norm)
34+
__all__ = []
4135

42-
__all__ = linalg_all + _linalg.__all__
36+
__all__ += _cupy_linalg_all
4337

44-
del get_xp
45-
del cp
46-
del linalg_all
47-
del _linalg
38+
__all__ += [
39+
"cross",
40+
"matmul",
41+
"outer",
42+
"tensordot",
43+
"EighResult",
44+
"QRResult",
45+
"SlogdetResult",
46+
"SVDResult",
47+
"eigh",
48+
"qr",
49+
"slogdet",
50+
"svd",
51+
"cholesky",
52+
"matrix_rank",
53+
"pinv",
54+
"matrix_norm",
55+
"matrix_transpose",
56+
"svdvals",
57+
"vecdot",
58+
"vector_norm",
59+
"diagonal",
60+
"trace",
61+
]

0 commit comments

Comments
 (0)