Skip to content

Commit c55eb43

Browse files
committed
Fix the tril and triu functions in torch
They did not support the k keyword argument.
1 parent 0cf6967 commit c55eb43

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,14 @@ def empty(shape: Union[int, Tuple[int, ...]],
543543
**kwargs) -> array:
544544
return torch.empty(shape, dtype=dtype, device=device, **kwargs)
545545

546+
# tril and triu do not call the keyword argument k
547+
548+
def tril(x: array, /, *, k: int = 0) -> array:
549+
return torch.tril(x, k)
550+
551+
def triu(x: array, /, *, k: int = 0) -> array:
552+
return torch.triu(x, k)
553+
546554
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
547555
def expand_dims(x: array, /, *, axis: int = 0) -> array:
548556
return torch.unsqueeze(x, axis)
@@ -610,6 +618,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
610618
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
611619
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
612620
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
613-
'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
614-
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
615-
'matmul', 'matrix_transpose', 'vecdot', 'tensordot']
621+
'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
622+
'broadcast_arrays', 'unique_all', 'unique_counts',
623+
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
624+
'vecdot', 'tensordot']

0 commit comments

Comments
 (0)