Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 74a9bd2

Browse files
authored
Make 'Module' refine 'EuclideanDifferentiable'. (#491)
`EuclideanDifferentiable`, introduced in swiftlang/swift#26827 and swiftlang/swift#26867, provides a property, `differentiableVectorView`, to project the vector space component of a differentiable struct as a value of `TangentVector` type. This allows us to express optimization techniques that require the parameter space to be a vector space, e.g. weight decay: ```swift let 𝛁L: Model.TangentVector = ... model.move(along: -η * 𝛁L - η * λ * model.differentiableVectorView)) ``` This patch makes `Module` refine `EuclideanDifferentiable` and makes all layer combinators (e.g. `Sequential`, `RNN`) be `EuclideanDifferentiable`. All `Module`s and `Layers`s will now have property `differentiableVectorView`. Resolves #456.
1 parent b7f2a06 commit 74a9bd2

File tree

4 files changed

+12
-5
lines changed

4 files changed

+12
-5
lines changed

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,6 @@ extension Tensor: PointwiseMultiplicative where Scalar: Numeric {
573573
// Differentiable
574574
//===------------------------------------------------------------------------------------------===//
575575

576-
extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
576+
extension Tensor: Differentiable & EuclideanDifferentiable where Scalar: TensorFlowFloatingPoint {
577577
public typealias TangentVector = Tensor
578578
}

Sources/TensorFlow/Layer.swift

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
public protocol Module: Differentiable, KeyPathIterable
15+
public protocol Module: EuclideanDifferentiable, KeyPathIterable
1616
where TangentVector: VectorProtocol & ElementaryFunctions &
1717
PointwiseMultiplicative & KeyPathIterable {
1818
/// The input type of the layer.
@@ -52,9 +52,10 @@ public extension Layer {
5252
}
5353

5454
/// An empty struct representing empty `TangentVector`s for parameterless layers.
55-
public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunctions,
55+
public struct EmptyTangentVector: EuclideanDifferentiable, VectorProtocol, ElementaryFunctions,
5656
PointwiseMultiplicative, KeyPathIterable {
5757
public typealias VectorSpaceScalar = Float
58+
public typealias TangentVector = Self
5859

5960
public func adding(_ x: Float) -> EmptyTangentVector { self }
6061
public mutating func add(_ x: Float) {}
@@ -67,13 +68,14 @@ public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunc
6768
/// A parameterless neural network layer.
6869
///
6970
/// The `TangentVector` of parameterless layers is always `EmptyTangentVector`.
70-
public protocol ParameterlessLayer: Layer {
71+
public protocol ParameterlessLayer: Layer where TangentVector == EmptyTangentVector {
7172
@differentiable
7273
func callAsFunction(_ input: Input) -> Output
7374
}
7475

7576
public extension ParameterlessLayer {
7677
mutating func move(along direction: EmptyTangentVector) {}
78+
var differentiableVectorView: EmptyTangentVector { EmptyTangentVector() }
7779
}
7880

7981
public extension Layer {

Sources/TensorFlow/Layers/Recurrent.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ public struct RNNCellInput<Input: Differentiable, State: Differentiable>: Differ
2626
}
2727
}
2828

29+
extension RNNCellInput: EuclideanDifferentiable
30+
where Input: EuclideanDifferentiable, State: EuclideanDifferentiable {}
31+
2932
/// An output to a recurrent neural network.
3033
public struct RNNCellOutput<Output: Differentiable, State: Differentiable>: Differentiable {
3134
/// The output at the current time step.
@@ -40,6 +43,9 @@ public struct RNNCellOutput<Output: Differentiable, State: Differentiable>: Diff
4043
}
4144
}
4245

46+
extension RNNCellOutput: EuclideanDifferentiable
47+
where Output: EuclideanDifferentiable, State: EuclideanDifferentiable {}
48+
4349
/// A recurrent neural network cell.
4450
public protocol RNNCell: Layer
4551
where Input == RNNCellInput<TimeStepInput, State>,

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,6 @@ final class LayerTests: XCTestCase {
243243
XCTAssertEqual(output, expected)
244244
}
245245

246-
247246
func testZeroPadding1D() {
248247
let input = Tensor<Float>([0.0, 1.0, 2.0])
249248
let layer = ZeroPadding1D<Float>(padding: 2)

0 commit comments

Comments
 (0)