1
1
from __future__ import annotations
2
2
3
+ from builtins import all as builtin_all
4
+ from builtins import any as builtin_any
3
5
from functools import wraps
4
- from builtins import all as builtin_all , any as builtin_any
5
-
6
- from ..common ._aliases import (UniqueAllResult , UniqueCountsResult ,
7
- UniqueInverseResult ,
8
- matrix_transpose as _aliases_matrix_transpose ,
9
- vecdot as _aliases_vecdot )
10
- from .._internal import get_xp
6
+ from typing import TYPE_CHECKING
11
7
12
8
import torch
13
9
14
- from typing import TYPE_CHECKING
10
+ from .._internal import get_xp
11
+ from ..common ._aliases import UniqueAllResult , UniqueCountsResult , UniqueInverseResult
12
+ from ..common ._aliases import matrix_transpose as _aliases_matrix_transpose
13
+ from ..common ._aliases import vecdot as _aliases_vecdot
14
+
15
15
if TYPE_CHECKING :
16
16
from typing import List , Optional , Sequence , Tuple , Union
17
- from .. common . _typing import Device
17
+
18
18
from torch import dtype as Dtype
19
19
20
+ from ..common ._typing import Device
21
+
20
22
array = torch .Tensor
21
23
22
24
_int_dtypes = {
@@ -693,15 +695,42 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
693
695
axis = 0
694
696
return torch .index_select (x , axis , indices , ** kwargs )
695
697
696
- __all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' , 'newaxis' ,
697
- 'add' , 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
698
- 'bitwise_right_shift' , 'bitwise_xor' , 'divide' , 'equal' ,
699
- 'floor_divide' , 'greater' , 'greater_equal' , 'less' , 'less_equal' ,
700
- 'logaddexp' , 'multiply' , 'not_equal' , 'pow' , 'remainder' ,
701
- 'subtract' , 'max' , 'min' , 'sort' , 'prod' , 'sum' , 'any' , 'all' ,
702
- 'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'broadcast_to' , 'flip' , 'roll' ,
703
- 'nonzero' , 'where' , 'reshape' , 'arange' , 'eye' , 'linspace' , 'full' ,
704
- 'ones' , 'zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
705
- 'broadcast_arrays' , 'unique_all' , 'unique_counts' ,
706
- 'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
707
- 'vecdot' , 'tensordot' , 'isdtype' , 'take' ]
698
+
699
+
700
+ # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
701
+ # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
702
+ def cross (x1 : array , x2 : array , / , * , axis : int = - 1 ) -> array :
703
+ x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
704
+ return torch .linalg .cross (x1 , x2 , dim = axis )
705
+
706
+ def vecdot_linalg (x1 : array , x2 : array , / , * , axis : int = - 1 , ** kwargs ) -> array :
707
+ from ._aliases import isdtype
708
+
709
+ x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
710
+
711
+ # torch.linalg.vecdot doesn't support integer dtypes
712
+ if isdtype (x1 .dtype , 'integral' ) or isdtype (x2 .dtype , 'integral' ):
713
+ if kwargs :
714
+ raise RuntimeError ("vecdot kwargs not supported for integral dtypes" )
715
+ ndim = max (x1 .ndim , x2 .ndim )
716
+ x1_shape = (1 ,)* (ndim - x1 .ndim ) + tuple (x1 .shape )
717
+ x2_shape = (1 ,)* (ndim - x2 .ndim ) + tuple (x2 .shape )
718
+ if x1_shape [axis ] != x2_shape [axis ]:
719
+ raise ValueError ("x1 and x2 must have the same size along the given axis" )
720
+
721
+ x1_ , x2_ = torch .broadcast_tensors (x1 , x2 )
722
+ x1_ = torch .moveaxis (x1_ , axis , - 1 )
723
+ x2_ = torch .moveaxis (x2_ , axis , - 1 )
724
+
725
+ res = x1_ [..., None , :] @ x2_ [..., None ]
726
+ return res [..., 0 , 0 ]
727
+ return torch .linalg .vecdot (x1 , x2 , dim = axis , ** kwargs )
728
+
729
+ def solve (x1 : array , x2 : array , / , ** kwargs ) -> array :
730
+ x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
731
+ return torch .linalg .solve (x1 , x2 , ** kwargs )
732
+
733
+ # torch.trace doesn't support the offset argument and doesn't support stacking
734
+ def trace (x : array , / , * , offset : int = 0 , dtype : Optional [Dtype ] = None ) -> array :
735
+ # Use our wrapped sum to make sure it does upcasting correctly
736
+ return sum (torch .diagonal (x , offset = offset , dim1 = - 2 , dim2 = - 1 ), axis = - 1 , dtype = dtype )
0 commit comments