|
1 |
| -from __future__ import annotations |
| 1 | +import torch as _torch |
2 | 2 |
|
3 |
| -from typing import TYPE_CHECKING |
4 |
| -if TYPE_CHECKING: |
5 |
| - import torch |
6 |
| - array = torch.Tensor |
7 |
| - from torch import dtype as Dtype |
8 |
| - from typing import Optional |
| 3 | +from .._internal import _get_all_public_members |
9 | 4 |
|
10 |
| -from torch.linalg import * |
| 5 | +_torch_linalg_all = _get_all_public_members(_torch.linalg) |
11 | 6 |
|
12 |
| -# torch.linalg doesn't define __all__ |
13 |
| -# from torch.linalg import __all__ as linalg_all |
14 |
| -from torch import linalg as torch_linalg |
15 |
| -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] |
| 7 | +for _name in _torch_linalg_all: |
| 8 | + globals()[_name] = getattr(_torch.linalg, _name) |
16 | 9 |
|
17 | 10 | # outer is implemented in torch but aren't in the linalg namespace
|
18 |
| -from torch import outer |
19 |
| -from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum |
20 |
| - |
21 |
| -# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the |
22 |
| -# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 |
23 |
| -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: |
24 |
| - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) |
25 |
| - return torch_linalg.cross(x1, x2, dim=axis) |
26 |
| - |
27 |
| -def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: |
28 |
| - from ._aliases import isdtype |
29 |
| - |
30 |
| - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) |
31 |
| - |
32 |
| - # torch.linalg.vecdot doesn't support integer dtypes |
33 |
| - if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): |
34 |
| - if kwargs: |
35 |
| - raise RuntimeError("vecdot kwargs not supported for integral dtypes") |
36 |
| - ndim = max(x1.ndim, x2.ndim) |
37 |
| - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) |
38 |
| - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) |
39 |
| - if x1_shape[axis] != x2_shape[axis]: |
40 |
| - raise ValueError("x1 and x2 must have the same size along the given axis") |
41 |
| - |
42 |
| - x1_, x2_ = torch.broadcast_tensors(x1, x2) |
43 |
| - x1_ = torch.moveaxis(x1_, axis, -1) |
44 |
| - x2_ = torch.moveaxis(x2_, axis, -1) |
45 |
| - |
46 |
| - res = x1_[..., None, :] @ x2_[..., None] |
47 |
| - return res[..., 0, 0] |
48 |
| - return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) |
49 |
| - |
50 |
| -def solve(x1: array, x2: array, /, **kwargs) -> array: |
51 |
| - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) |
52 |
| - return torch.linalg.solve(x1, x2, **kwargs) |
53 |
| - |
54 |
| -# torch.trace doesn't support the offset argument and doesn't support stacking |
55 |
| -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: |
56 |
| - # Use our wrapped sum to make sure it does upcasting correctly |
57 |
| - return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) |
58 |
| - |
59 |
| -__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', |
60 |
| - 'vecdot', 'solve'] |
61 |
| - |
62 |
| -del linalg_all |
| 11 | +outer = _torch.outer |
| 12 | + |
| 13 | +from ._aliases import ( # noqa: E402 |
| 14 | + matrix_transpose, |
| 15 | + solve, |
| 16 | + sum, |
| 17 | + tensordot, |
| 18 | + trace, |
| 19 | + vecdot_linalg as vecdot, |
| 20 | +) |
| 21 | + |
| 22 | +__all__ = [] |
| 23 | + |
| 24 | +__all__ += _torch_linalg_all |
| 25 | + |
| 26 | +__all__ += [ |
| 27 | + "matrix_transpose", |
| 28 | + "solve", |
| 29 | + "sum", |
| 30 | + "outer", |
| 31 | + "trace", |
| 32 | + "tensordot", |
| 33 | + "vecdot", |
| 34 | +] |
0 commit comments