Skip to content

Commit 306ed04

Browse files
committed
Fix ruff errors in numpy/linalg
1 parent 7dc025e commit 306ed04

File tree

1 file changed

+63
-40
lines changed

1 file changed

+63
-40
lines changed

array_api_compat/numpy/linalg.py

Lines changed: 63 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,63 @@
1-
from numpy.linalg import *
2-
from numpy.linalg import __all__ as linalg_all
3-
4-
from ..common import _linalg
5-
from .._internal import get_xp
6-
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)
7-
8-
import numpy as np
9-
10-
cross = get_xp(np)(_linalg.cross)
11-
outer = get_xp(np)(_linalg.outer)
12-
EighResult = _linalg.EighResult
13-
QRResult = _linalg.QRResult
14-
SlogdetResult = _linalg.SlogdetResult
15-
SVDResult = _linalg.SVDResult
16-
eigh = get_xp(np)(_linalg.eigh)
17-
qr = get_xp(np)(_linalg.qr)
18-
slogdet = get_xp(np)(_linalg.slogdet)
19-
svd = get_xp(np)(_linalg.svd)
20-
cholesky = get_xp(np)(_linalg.cholesky)
21-
matrix_rank = get_xp(np)(_linalg.matrix_rank)
22-
pinv = get_xp(np)(_linalg.pinv)
23-
matrix_norm = get_xp(np)(_linalg.matrix_norm)
24-
svdvals = get_xp(np)(_linalg.svdvals)
25-
diagonal = get_xp(np)(_linalg.diagonal)
26-
trace = get_xp(np)(_linalg.trace)
27-
28-
# These functions are completely new here. If the library already has them
29-
# (i.e., numpy 2.0), use the library version instead of our wrapper.
30-
if hasattr(np.linalg, 'vector_norm'):
31-
vector_norm = np.linalg.vector_norm
32-
else:
33-
vector_norm = get_xp(np)(_linalg.vector_norm)
34-
35-
__all__ = linalg_all + _linalg.__all__
36-
37-
del get_xp
38-
del np
39-
del linalg_all
40-
del _linalg
1+
import numpy as _np
2+
3+
from .._internal import _get_all_public_members
4+
5+
_numpy_linalg_all = _get_all_public_members(_np.linalg)
6+
7+
for _name in _numpy_linalg_all:
8+
globals()[_name] = getattr(_np.linalg, _name)
9+
10+
11+
from ._aliases import ( # noqa: E402
12+
EighResult,
13+
QRResult,
14+
SlogdetResult,
15+
SVDResult,
16+
cholesky,
17+
cross,
18+
diagonal,
19+
eigh,
20+
matmul,
21+
matrix_norm,
22+
matrix_rank,
23+
matrix_transpose,
24+
outer,
25+
pinv,
26+
qr,
27+
slogdet,
28+
svd,
29+
svdvals,
30+
tensordot,
31+
trace,
32+
vecdot,
33+
vector_norm,
34+
)
35+
36+
__all__ = []
37+
38+
__all__ += _numpy_linalg_all
39+
40+
__all__ += [
41+
"EighResult",
42+
"QRResult",
43+
"SlogdetResult",
44+
"SVDResult",
45+
"cholesky",
46+
"cross",
47+
"diagonal",
48+
"eigh",
49+
"matmul",
50+
"matrix_norm",
51+
"matrix_rank",
52+
"matrix_transpose",
53+
"outer",
54+
"pinv",
55+
"qr",
56+
"slogdet",
57+
"svd",
58+
"svdvals",
59+
"tensordot",
60+
"trace",
61+
"vecdot",
62+
"vector_norm",
63+
]

0 commit comments

Comments
 (0)