
Commit fb447e6: Fix ruff errors in torch/linalg
Parent commit: 59f08b2

File tree

1 file changed: +29 −57 lines

array_api_compat/torch/linalg.py

@@ -1,62 +1,34 @@
-from __future__ import annotations
+import torch as _torch
 
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    import torch
-    array = torch.Tensor
-    from torch import dtype as Dtype
-    from typing import Optional
+from .._internal import _get_all_public_members
 
-from torch.linalg import *
+_torch_linalg_all = _get_all_public_members(_torch.linalg)
 
-# torch.linalg doesn't define __all__
-# from torch.linalg import __all__ as linalg_all
-from torch import linalg as torch_linalg
-linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
+for _name in _torch_linalg_all:
+    globals()[_name] = getattr(_torch.linalg, _name)
 
 # outer is implemented in torch but isn't in the linalg namespace
-from torch import outer
-from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum
-
-# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
-# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
-    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    return torch_linalg.cross(x1, x2, dim=axis)
-
-def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
-    from ._aliases import isdtype
-
-    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-
-    # torch.linalg.vecdot doesn't support integer dtypes
-    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
-        if kwargs:
-            raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-        ndim = max(x1.ndim, x2.ndim)
-        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-        if x1_shape[axis] != x2_shape[axis]:
-            raise ValueError("x1 and x2 must have the same size along the given axis")
-
-        x1_, x2_ = torch.broadcast_tensors(x1, x2)
-        x1_ = torch.moveaxis(x1_, axis, -1)
-        x2_ = torch.moveaxis(x2_, axis, -1)
-
-        res = x1_[..., None, :] @ x2_[..., None]
-        return res[..., 0, 0]
-    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
-
-def solve(x1: array, x2: array, /, **kwargs) -> array:
-    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    return torch.linalg.solve(x1, x2, **kwargs)
-
-# torch.trace doesn't support the offset argument and doesn't support stacking
-def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
-    # Use our wrapped sum to make sure it does upcasting correctly
-    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
-
-__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
-    'vecdot', 'solve']
-
-del linalg_all
+outer = _torch.outer
+
+from ._aliases import (  # noqa: E402
+    matrix_transpose,
+    solve,
+    sum,
+    tensordot,
+    trace,
+    vecdot_linalg as vecdot,
+)
+
+__all__ = []
+
+__all__ += _torch_linalg_all
+
+__all__ += [
+    "matrix_transpose",
+    "solve",
+    "sum",
+    "outer",
+    "trace",
+    "tensordot",
+    "vecdot",
+]
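The new module pulls every public name through `_get_all_public_members`, whose implementation is not shown in this diff. A minimal sketch of what such a helper plausibly does, given that the old code fell back to filtering `dir()` because `torch.linalg` defines no `__all__` (the real helper lives in `array_api_compat/_internal.py` and may differ):

# Hypothetical sketch only; the actual .._internal helper may differ.
def _get_all_public_members(module):
    """Return the names that `from module import *` would effectively bind."""
    # Use an explicit __all__ when the module provides one.
    all_names = getattr(module, "__all__", None)
    if all_names is not None:
        return list(all_names)
    # Otherwise fall back to non-underscore attributes, mirroring the old
    # `[i for i in dir(torch_linalg) if not i.startswith('_')]`.
    return [name for name in dir(module) if not name.startswith("_")]

Assigning each name into `globals()` reproduces the effect of the deleted `from torch.linalg import *` without tripping ruff's star-import checks.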
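The integer-dtype `vecdot` fallback deleted above is re-imported as `vecdot_linalg` from `._aliases`, so it was presumably relocated rather than dropped. It computes the dot product as a batched matmul because `torch.linalg.vecdot` rejects integer inputs; a self-contained version of that trick:

import torch

def int_vecdot(x1: torch.Tensor, x2: torch.Tensor, axis: int = -1) -> torch.Tensor:
    # Broadcast the inputs against each other, then move the reduction
    # axis to the end so the contraction can be written as a matmul.
    x1_, x2_ = torch.broadcast_tensors(x1, x2)
    x1_ = torch.moveaxis(x1_, axis, -1)
    x2_ = torch.moveaxis(x2_, axis, -1)
    # (..., 1, n) @ (..., n, 1) -> (..., 1, 1); drop the two unit dims.
    res = x1_[..., None, :] @ x2_[..., None]
    return res[..., 0, 0]

a = torch.arange(6).reshape(2, 3)   # int64 by default
b = torch.ones(3, dtype=torch.int64)
print(int_vecdot(a, b))             # tensor([ 3, 12])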
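Similarly, the removed `trace` wrapper worked around `torch.trace`, which supports neither an `offset` argument nor batched inputs, by summing the batched diagonal. A standalone illustration using a plain `torch` sum in place of the package's wrapped `sum`:

import torch

def batched_trace(x: torch.Tensor, offset: int = 0) -> torch.Tensor:
    # Take the diagonal of the trailing two dims, then sum it; this
    # handles both the offset and stacks of matrices.
    return torch.diagonal(x, offset=offset, dim1=-2, dim2=-1).sum(dim=-1)

x = torch.arange(8.0).reshape(2, 2, 2)   # a stack of two 2x2 matrices
print(batched_trace(x))                  # tensor([ 3., 11.])
print(batched_trace(x, offset=1))        # tensor([1., 5.])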
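For context, the net effect of the re-export loop plus the `._aliases` imports is a single namespace exposing both untouched `torch.linalg` functions and the spec-compliant replacements. A usage sketch, assuming `array_api_compat` is installed:

import torch
from array_api_compat.torch import linalg

x = torch.eye(3)
print(linalg.matrix_rank(x))   # re-exported untouched from torch.linalg
print(linalg.vecdot(x, x))     # wrapped alias (vecdot_linalg) from ._aliases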
