Skip to content

Commit 8c872f0

Browse files
committed
Update torch.take and torch.result_type for the 2022 spec
1 parent 9ef7f72 commit 8c872f0

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

array_api_compat/torch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919

2020
from ..common._helpers import *
2121

22-
__array_api_version__ = '2021.12'
22+
__array_api_version__ = '2022.12'

array_api_compat/torch/_aliases.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
*_int_dtypes,
3333
torch.float32,
3434
torch.float64,
35+
torch.complex64,
36+
torch.complex128,
3537
}
3638

3739
_promotion_table = {
@@ -70,6 +72,16 @@
7072
(torch.float32, torch.float64): torch.float64,
7173
(torch.float64, torch.float32): torch.float64,
7274
(torch.float64, torch.float64): torch.float64,
75+
# complexes
76+
(torch.complex64, torch.complex64): torch.complex64,
77+
(torch.complex64, torch.complex128): torch.complex128,
78+
(torch.complex128, torch.complex64): torch.complex128,
79+
(torch.complex128, torch.complex128): torch.complex128,
80+
# Mixed float and complex
81+
(torch.float32, torch.complex64): torch.complex64,
82+
(torch.float32, torch.complex128): torch.complex128,
83+
(torch.float64, torch.complex64): torch.complex128,
84+
(torch.float64, torch.complex128): torch.complex128,
7385
}
7486

7587

@@ -652,6 +664,9 @@ def isdtype(
652664
else:
653665
return dtype == kind
654666

667+
def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array:
668+
return torch.index_select(x, axis, indices, **kwargs)
669+
655670
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
656671
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
657672
'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
@@ -663,4 +678,4 @@ def isdtype(
663678
'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
664679
'broadcast_arrays', 'unique_all', 'unique_counts',
665680
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
666-
'vecdot', 'tensordot', 'isdtype']
681+
'vecdot', 'tensordot', 'isdtype', 'take']

0 commit comments

Comments
 (0)