This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5caa460

TF-425: Make learning rates dynamically settable. (#81)
Modern training regimes require the ability to change the learning rate during training according to a particular schedule. By allowing the learning rate to be set through the Optimizer protocol, such schedules become easy to implement with callbacks.
1 parent f4cd012 commit 5caa460
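
For instance, with a settable learning rate a step-decay schedule can be written as a plain loop over epochs. The sketch below is illustrative only: the `Classifier` model type, the `trainOneEpoch(...)` helper, and the decay constants are assumptions, not part of this commit.

```swift
// Illustrative sketch (not part of this commit): step-decay the learning
// rate during training by assigning to the now-settable `learningRate`.
// `Classifier` and `trainOneEpoch(...)` are hypothetical stand-ins, and
// the SGD initializer is assumed to default its remaining parameters.
var model = Classifier()
let optimizer = SGD<Classifier, Float>(learningRate: 0.1)

for epoch in 1...30 {
    trainOneEpoch(model: &model, optimizer: optimizer)
    if epoch % 10 == 0 {
        // Halve the learning rate every 10 epochs (an example schedule).
        optimizer.learningRate /= 2
    }
}
```

Because the optimizers are classes, the `optimizer` reference can stay a `let` while its `learningRate` property is reassigned between epochs.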

File tree

1 file changed (12 additions, 12 deletions)

Sources/DeepLearning/Optimizer.swift

Lines changed: 12 additions & 12 deletions
@@ -26,7 +26,7 @@ public protocol Optimizer {
     /// The scalar parameter type.
     associatedtype Scalar: FloatingPoint
     /// The learning rate.
-    var learningRate: Scalar { get }
+    var learningRate: Scalar { get set }
     /// Updates the specified differentiable variables along the specified
     /// direction.
     mutating func update(_ variables: inout Model.AllDifferentiableVariables,
@@ -42,17 +42,17 @@ public protocol Optimizer {
 public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     where Model.AllDifferentiableVariables == Model.CotangentVector {
     /// The learning rate.
-    public let learningRate: Scalar
+    public var learningRate: Scalar
     /// A coefficient used to calculate the first and second moments of
     /// gradients.
     public var beta1: Scalar
     /// A coefficient used to calculate the first and second moments of
     /// gradients.
     public var beta2: Scalar
     /// A small scalar added to the denominator to improve numerical stability.
-    public let epsilon: Scalar
+    public var epsilon: Scalar
     /// The weight decay.
-    public let decay: Scalar
+    public var decay: Scalar

     public init(
         learningRate: Scalar = 1e-3,
@@ -122,13 +122,13 @@ public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
 public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     where Model.AllDifferentiableVariables == Model.CotangentVector {
     /// The learning rate.
-    public let learningRate: Scalar
+    public var learningRate: Scalar
     // TODO: Document `rho`. Keras doesn't document `rho`.
-    public let rho: Scalar
+    public var rho: Scalar
     /// A small scalar added to the denominator to improve numerical stability.
-    public let epsilon: Scalar
+    public var epsilon: Scalar
     /// The weight decay.
-    public let decay: Scalar
+    public var decay: Scalar

     public init(
         learningRate: Scalar = 0.001,
@@ -180,14 +180,14 @@ public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
 public class SGD<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     where Model.AllDifferentiableVariables == Model.CotangentVector {
     /// The learning rate.
-    public let learningRate: Scalar
+    public var learningRate: Scalar
     /// The momentum factor. It accelerates stochastic gradient descent in the relevant direction
     /// and dampens oscillations.
-    public let momentum: Scalar
+    public var momentum: Scalar
     /// The weight decay.
-    public let decay: Scalar
+    public var decay: Scalar
     /// Use Nesterov momentum if true.
-    public let nesterov: Bool
+    public var nesterov: Bool

     public init(
         learningRate: Scalar = 0.01,
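
Because the protocol requirement itself is now `{ get set }`, schedule logic can also be written generically against `Optimizer` rather than against a concrete optimizer class. A minimal sketch follows; the helper name and the inverse-time-decay formula are illustrative, not part of this commit.

```swift
// Illustrative sketch: a schedule helper written against the Optimizer
// protocol, possible now that `learningRate` is `{ get set }` on the
// protocol itself. The function name and formula are hypothetical.
func applyInverseTimeDecay<Opt: Optimizer>(
    to optimizer: inout Opt,
    initialRate: Opt.Scalar,
    decay: Opt.Scalar,
    step: Opt.Scalar
) {
    // rate = initialRate / (1 + decay * step)
    optimizer.learningRate = initialRate / (1 + decay * step)
}
```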
