@@ -56,6 +56,12 @@ public extension Tensor where Scalar: TensorFlowNumeric {
56
56
_Raw. matrixDiag ( diagonal: self )
57
57
}
58
58
59
+ @available ( * , deprecated, renamed: " bandPart(subdiagonalCount:superdiagonalCount:) " )
60
+ @differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
61
+ func bandPart( _ subdiagonalCount: Int , _ superdiagonalCount: Int ) -> Tensor {
62
+ return bandPart ( subdiagonalCount: subdiagonalCount, superdiagonalCount: superdiagonalCount)
63
+ }
64
+
59
65
/// Returns a copy of a innermost tensor defined by a central band boundaries.
60
66
/// The output is a tensor of the same shape as the instance `[..., :, :]`.
61
67
///
@@ -79,12 +85,18 @@ public extension Tensor where Scalar: TensorFlowNumeric {
79
85
/// // [-2, -1, 0, 1]
80
86
/// // [ 0, -2, -1, 0]]
81
87
/// ```
88
+ ///
89
+ /// - Parameters:
90
+ /// - subdiagonalCount: The number of subdiagonals to keep. If negative, keep entire lower
91
+ /// triangle.
92
+ /// - superdiagonalCount: The number of superdiagonals to keep. If negative, keep entire upper
93
+ /// triangle.
82
94
@inlinable
83
95
@differentiable ( wrt: self , vjp: _vjpBandPart where Scalar: TensorFlowFloatingPoint)
84
- func bandPart( _ lowerCount : Int , _ upperCount : Int ) -> Tensor {
96
+ func bandPart( subdiagonalCount : Int , superdiagonalCount : Int ) -> Tensor {
85
97
precondition ( rank >= 2 , " The tensor must have at least rank 2. " )
86
- let lower = Tensor < Int32 > ( Int32 ( lowerCount ) )
87
- let upper = Tensor < Int32 > ( Int32 ( upperCount ) )
98
+ let lower = Tensor < Int32 > ( Int32 ( subdiagonalCount ) )
99
+ let upper = Tensor < Int32 > ( Int32 ( superdiagonalCount ) )
88
100
return _Raw. matrixBandPart ( self , numLower: lower, numUpper: upper)
89
101
}
90
102
}
@@ -101,8 +113,15 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
101
113
}
102
114
103
115
@inlinable
104
- func _vjpBandPart( _ numLower: Int , _ numUpper: Int ) -> ( Tensor , ( Tensor ) -> Tensor ) {
105
- ( bandPart ( numLower, numUpper) , { $0. bandPart ( numLower, numUpper) } )
116
+ func _vjpBandPart(
117
+ subdiagonalCount: Int , superdiagonalCount: Int
118
+ ) -> ( Tensor , ( Tensor ) -> Tensor ) {
119
+ let value = bandPart (
120
+ subdiagonalCount: subdiagonalCount,
121
+ superdiagonalCount: superdiagonalCount)
122
+ return ( value, {
123
+ $0. bandPart ( subdiagonalCount: subdiagonalCount, superdiagonalCount: superdiagonalCount)
124
+ } )
106
125
}
107
126
}
108
127
0 commit comments