Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 4bab884

Browse files
authored
Add subdiagonalCount:superdiagonalCount: argument labels to Tensor.bandPart. (#588)
Argument labels improve clarity. Deprecate label-less version, which has one user in tensorflow/swift-models.
1 parent c84d4c0 commit 4bab884

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

Sources/TensorFlow/Operators/LinearAlgebra.swift

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ public extension Tensor where Scalar: TensorFlowNumeric {
5656
_Raw.matrixDiag(diagonal: self)
5757
}
5858

59+
@available(*, deprecated, renamed: "bandPart(subdiagonalCount:superdiagonalCount:)")
60+
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
61+
func bandPart(_ subdiagonalCount: Int, _ superdiagonalCount: Int) -> Tensor {
62+
return bandPart(subdiagonalCount: subdiagonalCount, superdiagonalCount: superdiagonalCount)
63+
}
64+
5965
/// Returns a copy of a innermost tensor defined by a central band boundaries.
6066
/// The output is a tensor of the same shape as the instance `[..., :, :]`.
6167
///
@@ -79,12 +85,18 @@ public extension Tensor where Scalar: TensorFlowNumeric {
7985
/// // [-2, -1, 0, 1]
8086
/// // [ 0, -2, -1, 0]]
8187
/// ```
88+
///
89+
/// - Parameters:
90+
/// - subdiagonalCount: The number of subdiagonals to keep. If negative, keep entire lower
91+
/// triangle.
92+
/// - superdiagonalCount: The number of superdiagonals to keep. If negative, keep entire upper
93+
/// triangle.
8294
@inlinable
8395
@differentiable(wrt: self, vjp: _vjpBandPart where Scalar: TensorFlowFloatingPoint)
84-
func bandPart(_ lowerCount: Int, _ upperCount: Int) -> Tensor {
96+
func bandPart(subdiagonalCount: Int, superdiagonalCount: Int) -> Tensor {
8597
precondition(rank >= 2, "The tensor must have at least rank 2.")
86-
let lower = Tensor<Int32>(Int32(lowerCount))
87-
let upper = Tensor<Int32>(Int32(upperCount))
98+
let lower = Tensor<Int32>(Int32(subdiagonalCount))
99+
let upper = Tensor<Int32>(Int32(superdiagonalCount))
88100
return _Raw.matrixBandPart(self, numLower: lower, numUpper: upper)
89101
}
90102
}
@@ -101,8 +113,15 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
101113
}
102114

103115
@inlinable
104-
func _vjpBandPart(_ numLower: Int, _ numUpper: Int) -> (Tensor, (Tensor) -> Tensor) {
105-
(bandPart(numLower, numUpper), { $0.bandPart(numLower, numUpper) })
116+
func _vjpBandPart(
117+
subdiagonalCount: Int, superdiagonalCount: Int
118+
) -> (Tensor, (Tensor) -> Tensor) {
119+
let value = bandPart(
120+
subdiagonalCount: subdiagonalCount,
121+
superdiagonalCount: superdiagonalCount)
122+
return (value, {
123+
$0.bandPart(subdiagonalCount: subdiagonalCount, superdiagonalCount: superdiagonalCount)
124+
})
106125
}
107126
}
108127

Tests/TensorFlowTests/OperatorTests/MatrixTests.swift

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,17 @@ final class MatrixTests: XCTestCase {
8080
[-1, 0, 1, 2],
8181
[ 0, -1, 0, 1],
8282
[ 0, 0, -1, 0]])
83-
XCTAssertEqual(t1.bandPart(1, -1), target1)
83+
XCTAssertEqual(t1.bandPart(subdiagonalCount: 1, superdiagonalCount: -1), target1)
8484

8585
let target2 = Tensor<Float>([[ 0, 1, 0, 0],
8686
[-1, 0, 1, 0],
8787
[-2, -1, 0, 1],
8888
[ 0, -2, -1, 0]])
89-
XCTAssertEqual(t1.bandPart(2, 1), target2)
89+
XCTAssertEqual(t1.bandPart(subdiagonalCount: 2, superdiagonalCount: 1), target2)
9090

9191
// Test special case - diagonal
92-
XCTAssertEqual(t1.bandPart(0, 0), Tensor<Float>(zeros: [4, 4]))
92+
XCTAssertEqual(t1.bandPart(subdiagonalCount: 0, superdiagonalCount: 0),
93+
Tensor<Float>(zeros: [4, 4]))
9394

9495
// Test leading dimensions with special case - lower triangular
9596
let t2 = Tensor<Float>(stacking: [t1, t1 + 1])
@@ -101,12 +102,14 @@ final class MatrixTests: XCTestCase {
101102
[ 0, 1, 0, 0],
102103
[-1, 0, 1, 0],
103104
[-2, -1, 0, 1]]])
104-
XCTAssertEqual(t2.bandPart(-1, 0), target3)
105+
XCTAssertEqual(t2.bandPart(subdiagonalCount: -1, superdiagonalCount: 0), target3)
105106

106107
// Test bandPart gradient with special case - upper triangular
107108
let t3 = Tensor<Float>(shape: [2, 4, 4], scalars: (1...(2 * 16)).map(Float.init))
108-
let computedGrad = gradient(at: t3) { $0.squared().bandPart(0, -1).sum() }
109-
let expectedGrad = 2 * t3.bandPart(0, -1)
109+
let computedGrad = gradient(at: t3) {
110+
$0.squared().bandPart(subdiagonalCount: 0, superdiagonalCount: -1).sum()
111+
}
112+
let expectedGrad = 2 * t3.bandPart(subdiagonalCount: 0, superdiagonalCount: -1)
110113
XCTAssertEqual(computedGrad, expectedGrad)
111114
}
112115

0 commit comments

Comments
 (0)