
Commit 59f08b2

Move linalg aliases to torch/_aliases

1 parent c2a4f5f

1 file changed: +50 -21

array_api_compat/torch/_aliases.py

Lines changed: 50 additions & 21 deletions
@@ -1,22 +1,24 @@
 from __future__ import annotations
 
+from builtins import all as builtin_all
+from builtins import any as builtin_any
 from functools import wraps
-from builtins import all as builtin_all, any as builtin_any
-
-from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
-                               UniqueInverseResult,
-                               matrix_transpose as _aliases_matrix_transpose,
-                               vecdot as _aliases_vecdot)
-from .._internal import get_xp
+from typing import TYPE_CHECKING
 
 import torch
 
-from typing import TYPE_CHECKING
+from .._internal import get_xp
+from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult
+from ..common._aliases import matrix_transpose as _aliases_matrix_transpose
+from ..common._aliases import vecdot as _aliases_vecdot
+
 if TYPE_CHECKING:
     from typing import List, Optional, Sequence, Tuple, Union
-    from ..common._typing import Device
+
     from torch import dtype as Dtype
 
+    from ..common._typing import Device
+
 array = torch.Tensor
 
 _int_dtypes = {
@@ -693,15 +695,42 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
         axis = 0
     return torch.index_select(x, axis, indices, **kwargs)
 
-__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis',
-           'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
-           'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
-           'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal',
-           'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
-           'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
-           'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll',
-           'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
-           'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
-           'broadcast_arrays', 'unique_all', 'unique_counts',
-           'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
-           'vecdot', 'tensordot', 'isdtype', 'take']
+
+
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    return torch.linalg.cross(x1, x2, dim=axis)
+
+def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+    from ._aliases import isdtype
+
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+
+    # torch.linalg.vecdot doesn't support integer dtypes
+    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
+        if kwargs:
+            raise RuntimeError("vecdot kwargs not supported for integral dtypes")
+        ndim = max(x1.ndim, x2.ndim)
+        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
+        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
+        if x1_shape[axis] != x2_shape[axis]:
+            raise ValueError("x1 and x2 must have the same size along the given axis")
+
+        x1_, x2_ = torch.broadcast_tensors(x1, x2)
+        x1_ = torch.moveaxis(x1_, axis, -1)
+        x2_ = torch.moveaxis(x2_, axis, -1)
+
+        res = x1_[..., None, :] @ x2_[..., None]
+        return res[..., 0, 0]
+    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
+
+def solve(x1: array, x2: array, /, **kwargs) -> array:
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    return torch.linalg.solve(x1, x2, **kwargs)
+
+# torch.trace doesn't support the offset argument and doesn't support stacking
+def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
+    # Use our wrapped sum to make sure it does upcasting correctly
+    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
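
A note on the cross wrapper above: it always forwards the array-API axis argument to torch's dim, so the spec default of the last axis is applied explicitly rather than relying on torch's own default. A minimal sketch of what a call resolves to, using only public torch calls (the shapes are made up for illustration and are not from the commit):

import torch

# Both the first and the last axis have size 3 here, which is exactly the
# ambiguous case the comment in the diff warns about.
a = torch.randn(3, 4, 3)
b = torch.randn(3, 4, 3)

# cross(a, b) from the diff resolves to an explicit dim=-1, the array API default:
out = torch.linalg.cross(a, b, dim=-1)
print(out.shape)  # torch.Size([3, 4, 3])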
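The integer-dtype branch of vecdot_linalg is the subtlest part of the diff: since torch.linalg.vecdot rejects integer inputs, the fallback rewrites the dot product as a batched matrix product, viewing each length-k vector as a (1, k) row times a (k, 1) column. A worked sketch of that trick, with a non-default axis so the moveaxis step actually does something (values chosen by hand so the result is easy to check; this mirrors the fallback branch, it is not the wrapper itself):

import torch

# Integer inputs; contract along axis=0, which has size 3.
x1 = torch.arange(6).reshape(3, 2)          # [[0, 1], [2, 3], [4, 5]]
x2 = torch.ones(3, 2, dtype=torch.int64)
axis = 0

# Same steps as the fallback branch in the diff:
x1_, x2_ = torch.broadcast_tensors(x1, x2)
x1_ = torch.moveaxis(x1_, axis, -1)         # shape (2, 3)
x2_ = torch.moveaxis(x2_, axis, -1)

# (..., 1, k) @ (..., k, 1) -> (..., 1, 1); drop the two singleton dims.
res = x1_[..., None, :] @ x2_[..., None]
print(res[..., 0, 0])                       # tensor([6, 9]), the column sums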

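Similarly for trace: torch.trace only accepts a single 2-D matrix and has no offset parameter, so the wrapper reduces a batched torch.diagonal instead. The diff routes the reduction through the module's own wrapped sum to get the upcasting right; the sketch below uses plain .sum to stay self-contained, and the input values are invented for illustration:

import torch

x = torch.arange(2 * 3 * 3).reshape(2, 3, 3)   # a stack of two 3x3 matrices

# torch.trace(x) would fail here (2-D only, no offset). The batched
# equivalent used by the wrapper:
out = torch.diagonal(x, offset=1, dim1=-2, dim2=-1).sum(-1)
print(out)  # tensor([ 6, 24]): the offset-1 diagonals [1, 5] and [10, 14]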