Commit cc7c23d

[Optimizers] Remove SR-9595 workaround and rename 'gradient' argument to 'vector'.

1 parent e7e68a0

1 file changed: 8 additions, 11 deletions

Sources/DeepLearning/Optimizer.swift

Lines changed: 8 additions & 11 deletions
@@ -21,14 +21,13 @@ public protocol Optimizer {
     associatedtype Scalar: FloatingPoint
     var learningRate: Scalar { get }
     mutating func update(_ variables: inout Model.AllDifferentiableVariables,
-                         along gradient: Model.CotangentVector)
+                         along vector: Model.CotangentVector)
 }
 
 // MARK: - Key-path based optimizers
 
 public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
-    where Model.AllDifferentiableVariables: AdditiveArithmetic,
-          Model.AllDifferentiableVariables == Model.CotangentVector {
+    where Model.AllDifferentiableVariables == Model.CotangentVector {
     public let learningRate: Scalar
     public var beta1: Scalar
     public var beta2: Scalar
@@ -59,7 +58,7 @@ public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     private var secondMoments = Model.AllDifferentiableVariables.zero
 
     public func update(_ model: inout Model.AllDifferentiableVariables,
-                       along gradient: Model.AllDifferentiableVariables) {
+                       along vector: Model.AllDifferentiableVariables) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         let stepSize = learningRate * (sqrt(1 - pow(beta2, step)) / (1 - pow(beta1, step)))
@@ -76,8 +75,7 @@ public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
 }
 
 public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
-    where Model.AllDifferentiableVariables: AdditiveArithmetic,
-          Model.AllDifferentiableVariables == Model.CotangentVector {
+    where Model.AllDifferentiableVariables == Model.CotangentVector {
     public let learningRate: Scalar
     public let rho: Scalar
     public let epsilon: Scalar
@@ -103,7 +101,7 @@ public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     private var alpha = Model.AllDifferentiableVariables.zero
 
     public func update(_ model: inout Model.AllDifferentiableVariables,
-                       along gradient: Model.CotangentVector) {
+                       along vector: Model.CotangentVector) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Scalar>.self) {
@@ -116,8 +114,7 @@ public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
 }
 
 public class SGD<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
-    where Model.AllDifferentiableVariables: AdditiveArithmetic,
-          Model.AllDifferentiableVariables == Model.CotangentVector {
+    where Model.AllDifferentiableVariables == Model.CotangentVector {
     public let learningRate: Scalar
     public let momentum: Scalar
     public let decay: Scalar
@@ -143,7 +140,7 @@ public class SGD<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     private var velocity = Model.AllDifferentiableVariables.zero
 
     public func update(_ model: inout Model.AllDifferentiableVariables,
-                       along gradients: Model.CotangentVector) {
+                       along vectors: Model.CotangentVector) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Scalar>.self) {
@@ -170,7 +167,7 @@ public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
     }
 
     public func update(_ model: inout Model.AllDifferentiableVariables,
-                       along gradient: Model.CotangentVector) {
+                       along vector: Model.CotangentVector) {
         model = model.moved(along: learningRate * (.zero - model.tangentVector(from: gradient)))
     }
 }
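
Since `along` is the external argument label, renaming `gradient`/`gradients` to `vector`/`vectors` only changes the declarations; call sites still spell `update(..., along: ...)` and are unaffected. As a hypothetical sketch (not part of this commit), a minimal optimizer conforming to the updated protocol under the simplified `where` clause could look like the following. The type name `PlainGradientDescent`, its initializer, and the default learning rate are illustrative assumptions; the key-path update pattern mirrors the SGD and RMSProp classes in the diff.

// Hypothetical sketch, not part of this commit: a minimal optimizer that
// conforms to the updated protocol with the renamed 'along vector:' label.
public class PlainGradientDescent<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
    where Model.AllDifferentiableVariables == Model.CotangentVector {
    public let learningRate: Scalar

    // Illustrative initializer; the optimizers in this file define their own.
    public init(learningRate: Scalar = 0.01) {
        self.learningRate = learningRate
    }

    public func update(_ model: inout Model.AllDifferentiableVariables,
                       along vector: Model.CotangentVector) {
        // Step every trainable tensor against the supplied cotangent vector,
        // using the same key-path traversal as the classes above.
        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Scalar>.self) {
            model[keyPath: kp] -= learningRate * vector[keyPath: kp]
        }
    }
}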
