Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ff36d80

Browse files
authored
[AutoDiff] Define Tensor.zeroTangentVectorInitializer. (#1055)
`Tensor.zeroTangentVectorInitializer` now returns a zero tensor with the same shape as `self`, instead of a scalar zero tensor.
1 parent 7e7b3d3 commit ff36d80

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,11 @@ extension Tensor: PointwiseMultiplicative where Scalar: Numeric {
747747

748748
extension Tensor: Differentiable & EuclideanDifferentiable where Scalar: TensorFlowFloatingPoint {
749749
public typealias TangentVector = Tensor
750+
751+
public var zeroTangentVectorInitializer: () -> TangentVector {
752+
let shape = self.shape
753+
return { Tensor(zeros: shape) }
754+
}
750755
}
751756

752757
//===------------------------------------------------------------------------------------------===//

Tests/TensorFlowTests/TensorTests.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,18 @@ final class TensorTests: XCTestCase {
120120
)
121121
}
122122

123+
func testZeroTangentVectorInitializer() {
124+
let shape: TensorShape = [4, 5, 6]
125+
let tensor = Tensor<Float>(randomUniform: shape)
126+
XCTAssertEqual(tensor.zeroTangentVector, Tensor(zeros: shape))
127+
128+
struct TensorWrapper: Differentiable {
129+
var tensor: Tensor<Float>
130+
}
131+
let model = TensorWrapper(tensor: tensor)
132+
XCTAssertEqual(model.zeroTangentVector, .init(tensor: Tensor(zeros: shape)))
133+
}
134+
123135
static var allTests = [
124136
("testSimpleCond", testSimpleCond),
125137
("testRankGetter", testRankGetter),
@@ -129,5 +141,6 @@ final class TensorTests: XCTestCase {
129141
("testTensorShapeCollectionOperations", testTensorShapeCollectionOperations),
130142
("testInitShapeScalars", testInitShapeScalars),
131143
("testInitShapeScalarsDerivative", testInitShapeScalarsDerivative),
144+
("testZeroTangentVectorInitializer", testZeroTangentVectorInitializer),
132145
]
133146
}

0 commit comments

Comments
 (0)