@@ -141,7 +141,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
141
141
return torch .can_cast (from_ , to )
142
142
143
143
# Basic renames
144
- permute_dims = torch .permute
145
144
bitwise_invert = torch .bitwise_not
146
145
147
146
# Two-arg elementwise functions
@@ -451,18 +450,26 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
451
450
x = torch .squeeze (x , a )
452
451
return x
453
452
453
+ # torch.broadcast_to uses size instead of shape
454
+ def broadcast_to (x : array , / , shape : Tuple [int , ...], ** kwargs ) -> array :
455
+ return torch .broadcast_to (x , shape , ** kwargs )
456
+
457
+ # torch.permute uses dims instead of axes
458
+ def permute_dims (x : array , / , axes : Tuple [int , ...]) -> array :
459
+ return torch .permute (x , axes )
460
+
454
461
# The axis parameter doesn't work for flip() and roll()
455
462
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
456
463
# accept axis=None
457
- def flip (x : array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None ) -> array :
464
+ def flip (x : array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None , ** kwargs ) -> array :
458
465
if axis is None :
459
466
axis = tuple (range (x .ndim ))
460
467
# torch.flip doesn't accept dim as an int but the method does
461
468
# https://github.com/pytorch/pytorch/issues/18095
462
- return x .flip (axis )
469
+ return x .flip (axis , ** kwargs )
463
470
464
- def roll (x : array , / , shift : Union [int , Tuple [int , ...]], * , axis : Optional [Union [int , Tuple [int , ...]]] = None ) -> array :
465
- return torch .roll (x , shift , axis )
471
+ def roll (x : array , / , shift : Union [int , Tuple [int , ...]], * , axis : Optional [Union [int , Tuple [int , ...]]] = None , ** kwargs ) -> array :
472
+ return torch .roll (x , shift , axis , ** kwargs )
466
473
467
474
def nonzero (x : array , / , ** kwargs ) -> Tuple [array , ...]:
468
475
return torch .nonzero (x , as_tuple = True , ** kwargs )
@@ -673,9 +680,9 @@ def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array:
673
680
'floor_divide' , 'greater' , 'greater_equal' , 'less' , 'less_equal' ,
674
681
'logaddexp' , 'multiply' , 'not_equal' , 'pow' , 'remainder' ,
675
682
'subtract' , 'max' , 'min' , 'sort' , 'prod' , 'sum' , 'any' , 'all' ,
676
- 'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'flip ' , 'roll ' ,
677
- 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' , 'ones ' ,
678
- 'zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
683
+ 'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'broadcast_to ' , 'flip ' ,
684
+ 'roll' , ' nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' ,
685
+ 'ones' , ' zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
679
686
'broadcast_arrays' , 'unique_all' , 'unique_counts' ,
680
687
'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
681
688
'vecdot' , 'tensordot' , 'isdtype' , 'take' ]
0 commit comments