|
1 |
| -from numpy.linalg import * |
2 |
| -from numpy.linalg import __all__ as linalg_all |
3 |
| - |
4 |
| -from ..common import _linalg |
5 |
| -from .._internal import get_xp |
6 |
| -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) |
7 |
| - |
8 |
| -import numpy as np |
9 |
| - |
10 |
| -cross = get_xp(np)(_linalg.cross) |
11 |
| -outer = get_xp(np)(_linalg.outer) |
12 |
| -EighResult = _linalg.EighResult |
13 |
| -QRResult = _linalg.QRResult |
14 |
| -SlogdetResult = _linalg.SlogdetResult |
15 |
| -SVDResult = _linalg.SVDResult |
16 |
| -eigh = get_xp(np)(_linalg.eigh) |
17 |
| -qr = get_xp(np)(_linalg.qr) |
18 |
| -slogdet = get_xp(np)(_linalg.slogdet) |
19 |
| -svd = get_xp(np)(_linalg.svd) |
20 |
| -cholesky = get_xp(np)(_linalg.cholesky) |
21 |
| -matrix_rank = get_xp(np)(_linalg.matrix_rank) |
22 |
| -pinv = get_xp(np)(_linalg.pinv) |
23 |
| -matrix_norm = get_xp(np)(_linalg.matrix_norm) |
24 |
| -svdvals = get_xp(np)(_linalg.svdvals) |
25 |
| -diagonal = get_xp(np)(_linalg.diagonal) |
26 |
| -trace = get_xp(np)(_linalg.trace) |
27 |
| - |
28 |
| -# These functions are completely new here. If the library already has them |
29 |
| -# (i.e., numpy 2.0), use the library version instead of our wrapper. |
30 |
| -if hasattr(np.linalg, 'vector_norm'): |
31 |
| - vector_norm = np.linalg.vector_norm |
32 |
| -else: |
33 |
| - vector_norm = get_xp(np)(_linalg.vector_norm) |
34 |
| - |
35 |
| -__all__ = linalg_all + _linalg.__all__ |
36 |
| - |
37 |
| -del get_xp |
38 |
| -del np |
39 |
| -del linalg_all |
40 |
| -del _linalg |
| 1 | +import numpy as _np |
| 2 | + |
| 3 | +from .._internal import _get_all_public_members |
| 4 | + |
| 5 | +_numpy_linalg_all = _get_all_public_members(_np.linalg) |
| 6 | + |
| 7 | +for _name in _numpy_linalg_all: |
| 8 | + globals()[_name] = getattr(_np.linalg, _name) |
| 9 | + |
| 10 | + |
| 11 | +from ._aliases import ( # noqa: E402 |
| 12 | + EighResult, |
| 13 | + QRResult, |
| 14 | + SlogdetResult, |
| 15 | + SVDResult, |
| 16 | + cholesky, |
| 17 | + cross, |
| 18 | + diagonal, |
| 19 | + eigh, |
| 20 | + matmul, |
| 21 | + matrix_norm, |
| 22 | + matrix_rank, |
| 23 | + matrix_transpose, |
| 24 | + outer, |
| 25 | + pinv, |
| 26 | + qr, |
| 27 | + slogdet, |
| 28 | + svd, |
| 29 | + svdvals, |
| 30 | + tensordot, |
| 31 | + trace, |
| 32 | + vecdot, |
| 33 | + vector_norm, |
| 34 | +) |
| 35 | + |
| 36 | +__all__ = [] |
| 37 | + |
| 38 | +__all__ += _numpy_linalg_all |
| 39 | + |
| 40 | +__all__ += [ |
| 41 | + "EighResult", |
| 42 | + "QRResult", |
| 43 | + "SlogdetResult", |
| 44 | + "SVDResult", |
| 45 | + "cholesky", |
| 46 | + "cross", |
| 47 | + "diagonal", |
| 48 | + "eigh", |
| 49 | + "matmul", |
| 50 | + "matrix_norm", |
| 51 | + "matrix_rank", |
| 52 | + "matrix_transpose", |
| 53 | + "outer", |
| 54 | + "pinv", |
| 55 | + "qr", |
| 56 | + "slogdet", |
| 57 | + "svd", |
| 58 | + "svdvals", |
| 59 | + "tensordot", |
| 60 | + "trace", |
| 61 | + "vecdot", |
| 62 | + "vector_norm", |
| 63 | +] |
0 commit comments