|
1 |
| -from cupy.linalg import * |
2 |
| -# cupy.linalg doesn't have __all__. If it is added, replace this with |
3 |
| -# |
4 |
| -# from cupy.linalg import __all__ as linalg_all |
5 |
| -_n = {} |
6 |
| -exec('from cupy.linalg import *', _n) |
7 |
| -del _n['__builtins__'] |
8 |
| -linalg_all = list(_n) |
9 |
| -del _n |
| 1 | +import cupy as cp |
| 2 | +from .._internal import _get_all_public_members |
10 | 3 |
|
11 |
| -from ..common import _linalg |
12 |
| -from .._internal import get_xp |
13 |
| -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) |
| 4 | +_cupy_linalg_all = _get_all_public_members(cp.linalg) |
14 | 5 |
|
15 |
| -import cupy as cp |
| 6 | +for name in _cupy_linalg_all: |
| 7 | + globals()[name] = getattr(cp.linalg, name) |
16 | 8 |
|
17 |
| -cross = get_xp(cp)(_linalg.cross) |
18 |
| -outer = get_xp(cp)(_linalg.outer) |
19 |
| -EighResult = _linalg.EighResult |
20 |
| -QRResult = _linalg.QRResult |
21 |
| -SlogdetResult = _linalg.SlogdetResult |
22 |
| -SVDResult = _linalg.SVDResult |
23 |
| -eigh = get_xp(cp)(_linalg.eigh) |
24 |
| -qr = get_xp(cp)(_linalg.qr) |
25 |
| -slogdet = get_xp(cp)(_linalg.slogdet) |
26 |
| -svd = get_xp(cp)(_linalg.svd) |
27 |
| -cholesky = get_xp(cp)(_linalg.cholesky) |
28 |
| -matrix_rank = get_xp(cp)(_linalg.matrix_rank) |
29 |
| -pinv = get_xp(cp)(_linalg.pinv) |
30 |
| -matrix_norm = get_xp(cp)(_linalg.matrix_norm) |
31 |
| -svdvals = get_xp(cp)(_linalg.svdvals) |
32 |
| -diagonal = get_xp(cp)(_linalg.diagonal) |
33 |
| -trace = get_xp(cp)(_linalg.trace) |
| 9 | +from ._aliases import ( # noqa: E402 |
| 10 | + EighResult, |
| 11 | + QRResult, |
| 12 | + SlogdetResult, |
| 13 | + SVDResult, |
| 14 | + cholesky, |
| 15 | + cross, |
| 16 | + diagonal, |
| 17 | + eigh, |
| 18 | + matmul, |
| 19 | + matrix_norm, |
| 20 | + matrix_rank, |
| 21 | + matrix_transpose, |
| 22 | + outer, |
| 23 | + pinv, |
| 24 | + qr, |
| 25 | + slogdet, |
| 26 | + svd, |
| 27 | + svdvals, |
| 28 | + tensordot, |
| 29 | + trace, |
| 30 | + vecdot, |
| 31 | + vector_norm, |
| 32 | +) |
34 | 33 |
|
35 |
| -# These functions are completely new here. If the library already has them |
36 |
| -# (i.e., numpy 2.0), use the library version instead of our wrapper. |
37 |
| -if hasattr(cp.linalg, 'vector_norm'): |
38 |
| - vector_norm = cp.linalg.vector_norm |
39 |
| -else: |
40 |
| - vector_norm = get_xp(cp)(_linalg.vector_norm) |
| 34 | +__all__ = [] |
41 | 35 |
|
42 |
| -__all__ = linalg_all + _linalg.__all__ |
| 36 | +__all__ += _cupy_linalg_all |
43 | 37 |
|
44 |
| -del get_xp |
45 |
| -del cp |
46 |
| -del linalg_all |
47 |
| -del _linalg |
| 38 | +__all__ += [ |
| 39 | + "cross", |
| 40 | + "matmul", |
| 41 | + "outer", |
| 42 | + "tensordot", |
| 43 | + "EighResult", |
| 44 | + "QRResult", |
| 45 | + "SlogdetResult", |
| 46 | + "SVDResult", |
| 47 | + "eigh", |
| 48 | + "qr", |
| 49 | + "slogdet", |
| 50 | + "svd", |
| 51 | + "cholesky", |
| 52 | + "matrix_rank", |
| 53 | + "pinv", |
| 54 | + "matrix_norm", |
| 55 | + "matrix_transpose", |
| 56 | + "svdvals", |
| 57 | + "vecdot", |
| 58 | + "vector_norm", |
| 59 | + "diagonal", |
| 60 | + "trace", |
| 61 | +] |
0 commit comments