This repository was archived by the owner on Jul 1, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +18
-0
lines changed Expand file tree Collapse file tree 2 files changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -747,6 +747,11 @@ extension Tensor: PointwiseMultiplicative where Scalar: Numeric {
747
747
748
748
extension Tensor : Differentiable & EuclideanDifferentiable where Scalar: TensorFlowFloatingPoint {
749
749
public typealias TangentVector = Tensor
750
+
751
+ public var zeroTangentVectorInitializer : ( ) -> TangentVector {
752
+ let shape = self . shape
753
+ return { Tensor ( zeros: shape) }
754
+ }
750
755
}
751
756
752
757
//===------------------------------------------------------------------------------------------===//
Original file line number Diff line number Diff line change @@ -120,6 +120,18 @@ final class TensorTests: XCTestCase {
120
120
)
121
121
}
122
122
123
+ func testZeroTangentVectorInitializer( ) {
124
+ let shape : TensorShape = [ 4 , 5 , 6 ]
125
+ let tensor = Tensor < Float > ( randomUniform: shape)
126
+ XCTAssertEqual ( tensor. zeroTangentVector, Tensor ( zeros: shape) )
127
+
128
+ struct TensorWrapper : Differentiable {
129
+ var tensor : Tensor < Float >
130
+ }
131
+ let model = TensorWrapper ( tensor: tensor)
132
+ XCTAssertEqual ( model. zeroTangentVector, . init( tensor: Tensor ( zeros: shape) ) )
133
+ }
134
+
123
135
static var allTests = [
124
136
( " testSimpleCond " , testSimpleCond) ,
125
137
( " testRankGetter " , testRankGetter) ,
@@ -129,5 +141,6 @@ final class TensorTests: XCTestCase {
129
141
( " testTensorShapeCollectionOperations " , testTensorShapeCollectionOperations) ,
130
142
( " testInitShapeScalars " , testInitShapeScalars) ,
131
143
( " testInitShapeScalarsDerivative " , testInitShapeScalarsDerivative) ,
144
+ ( " testZeroTangentVectorInitializer " , testZeroTangentVectorInitializer) ,
132
145
]
133
146
}
You can’t perform that action at this time.
0 commit comments