Skip to content

Commit 50e273d

Browse files
committed
Fix some minor signature issues in the torch wrappers
1 parent 48cc745 commit 50e273d

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
141141
return torch.can_cast(from_, to)
142142

143143
# Basic renames
144-
permute_dims = torch.permute
145144
bitwise_invert = torch.bitwise_not
146145

147146
# Two-arg elementwise functions
@@ -451,18 +450,26 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
451450
x = torch.squeeze(x, a)
452451
return x
453452

453+
# torch.broadcast_to uses size instead of shape
454+
def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
455+
return torch.broadcast_to(x, shape, **kwargs)
456+
457+
# torch.permute uses dims instead of axes
458+
def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
459+
return torch.permute(x, axes)
460+
454461
# The axis parameter doesn't work for flip() and roll()
455462
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
456463
# accept axis=None
457-
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
464+
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
458465
if axis is None:
459466
axis = tuple(range(x.ndim))
460467
# torch.flip doesn't accept dim as an int but the method does
461468
# https://github.com/pytorch/pytorch/issues/18095
462-
return x.flip(axis)
469+
return x.flip(axis, **kwargs)
463470

464-
def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
465-
return torch.roll(x, shift, axis)
471+
def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
472+
return torch.roll(x, shift, axis, **kwargs)
466473

467474
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
468475
return torch.nonzero(x, as_tuple=True, **kwargs)
@@ -673,9 +680,9 @@ def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array:
673680
'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal',
674681
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
675682
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
676-
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
677-
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
678-
'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
683+
'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip',
684+
'roll', 'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
685+
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
679686
'broadcast_arrays', 'unique_all', 'unique_counts',
680687
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
681688
'vecdot', 'tensordot', 'isdtype', 'take']

0 commit comments

Comments
 (0)