@@ -543,6 +543,14 @@ def empty(shape: Union[int, Tuple[int, ...]],
543
543
** kwargs ) -> array :
544
544
return torch .empty (shape , dtype = dtype , device = device , ** kwargs )
545
545
546
+ # tril and triu do not call the keyword argument k
547
+
548
+ def tril (x : array , / , * , k : int = 0 ) -> array :
549
+ return torch .tril (x , k )
550
+
551
+ def triu (x : array , / , * , k : int = 0 ) -> array :
552
+ return torch .triu (x , k )
553
+
546
554
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
547
555
def expand_dims (x : array , / , * , axis : int = 0 ) -> array :
548
556
return torch .unsqueeze (x , axis )
@@ -610,6 +618,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
610
618
'subtract' , 'max' , 'min' , 'sort' , 'prod' , 'sum' , 'any' , 'all' ,
611
619
'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'flip' , 'roll' ,
612
620
'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' , 'ones' ,
613
- 'zeros' , 'empty' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
614
- 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
615
- 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' ]
621
+ 'zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
622
+ 'broadcast_arrays' , 'unique_all' , 'unique_counts' ,
623
+ 'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
624
+ 'vecdot' , 'tensordot' ]
0 commit comments