32
32
* _int_dtypes ,
33
33
torch .float32 ,
34
34
torch .float64 ,
35
+ torch .complex64 ,
36
+ torch .complex128 ,
35
37
}
36
38
37
39
_promotion_table = {
70
72
(torch .float32 , torch .float64 ): torch .float64 ,
71
73
(torch .float64 , torch .float32 ): torch .float64 ,
72
74
(torch .float64 , torch .float64 ): torch .float64 ,
75
+ # complexes
76
+ (torch .complex64 , torch .complex64 ): torch .complex64 ,
77
+ (torch .complex64 , torch .complex128 ): torch .complex128 ,
78
+ (torch .complex128 , torch .complex64 ): torch .complex128 ,
79
+ (torch .complex128 , torch .complex128 ): torch .complex128 ,
80
+ # Mixed float and complex
81
+ (torch .float32 , torch .complex64 ): torch .complex64 ,
82
+ (torch .float32 , torch .complex128 ): torch .complex128 ,
83
+ (torch .float64 , torch .complex64 ): torch .complex128 ,
84
+ (torch .float64 , torch .complex128 ): torch .complex128 ,
73
85
}
74
86
75
87
@@ -652,6 +664,9 @@ def isdtype(
652
664
else :
653
665
return dtype == kind
654
666
667
+ def take (x : array , indices : array , / , * , axis : int , ** kwargs ) -> array :
668
+ return torch .index_select (x , axis , indices , ** kwargs )
669
+
655
670
__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' , 'add' ,
656
671
'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
657
672
'bitwise_right_shift' , 'bitwise_xor' , 'divide' , 'equal' ,
@@ -663,4 +678,4 @@ def isdtype(
663
678
'zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
664
679
'broadcast_arrays' , 'unique_all' , 'unique_counts' ,
665
680
'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
666
- 'vecdot' , 'tensordot' , 'isdtype' ]
681
+ 'vecdot' , 'tensordot' , 'isdtype' , 'take' ]
0 commit comments