
Commit 12d7030

[Optimizer] Simplify optimizers using generalized vector math. (#218)
This is the start of a series of PRs that make optimizers no longer depend on `KeyPathIterable` or require `AllDifferentiableVariables` to equal `TangentVector`. This is possible because we recently overhauled generalized vector math. Changes include:

* Define an extension for `VectorProtocol` that defines arithmetic operators in terms of `adding(_:)`, `subtracting(_:)`, and `scaled(by:)`.
* Change `SGD.update(_:along:)` to use vector math.
* Make `Optimizer.update(_:along:)` take `inout Model` instead of `inout Model.AllDifferentiableVariables`. This makes it easy to deprecate `AllDifferentiableVariables` later.
* Add an `update(_:along:)` overload that takes `inout Model` so each optimizer conforms to the protocol without removing the existing implementation. This is for short-term source compatibility.
1 parent a2e3f9e commit 12d7030

File tree (5 files changed: 104 additions, 58 deletions)

* Sources/TensorFlow/Loss.swift
* Sources/TensorFlow/Operators/Math.swift
* Sources/TensorFlow/Optimizer.swift
* Tests/TensorFlowTests/SequentialTests.swift
* Tests/TensorFlowTests/TrivialModelTests.swift
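For orientation before the per-file diffs, here is a minimal sketch of the call-site change described in the commit message: with the new protocol requirement, a training loop passes the model itself to `update(_:along:)` rather than its `allDifferentiableVariables`. This sketch is not part of the commit; it assumes the Swift for TensorFlow APIs of this era (`Dense`, `SGD(for:learningRate:)`, `gradient(at:in:)`, `meanSquaredError(predicted:expected:)`), and the data is made up.

```swift
import TensorFlow

// Sketch only: illustrates the new call shape, not code from this commit.
var model = Dense<Float>(inputSize: 2, outputSize: 1, activation: sigmoid)
let optimizer = SGD(for: model, learningRate: 0.02)
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [[0], [1], [1], [0]]

let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
    meanSquaredError(predicted: model(x), expected: y)
}

// Old call site (kept working for short-term source compatibility):
//     optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
// New call site: the optimizer mutates the model directly.
optimizer.update(&model, along: 𝛁model)
```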

Sources/TensorFlow/Loss.swift

Lines changed: 2 additions & 1 deletion
@@ -231,6 +231,7 @@ public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
     // This numerical stable implementation is based on tf.nn.sigmoid_cross_entropy_with_logits.
 
     let maxLogitsWithZero = max(logits, Tensor(0))
-    let loss = maxLogitsWithZero - logits * labels + log(1 + exp(-abs(logits)))
+    var loss = maxLogitsWithZero - logits * labels
+    loss = loss + log(1 + exp(-abs(logits)))
     return loss.mean()
 }

Sources/TensorFlow/Operators/Math.swift

Lines changed: 43 additions & 11 deletions
@@ -161,22 +161,54 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
 // Vector Space
 //===------------------------------------------------------------------------------------------===//
 
-extension Tensor: VectorProtocol where Scalar: Numeric {
-    public typealias VectorSpaceScalar = Scalar
+extension Tensor: VectorProtocol where Scalar: TensorFlowFloatingPoint {
+    public typealias VectorSpaceScalar = Float
 
-    @differentiable(where Scalar: TensorFlowFloatingPoint)
-    public func adding(_ scalar: Scalar) -> Self {
-        self + scalar
+    // @differentiable(where Scalar: TensorFlowFloatingPoint)
+    public func scaled(by scale: Float) -> Self {
+        Scalar(scale) * self
     }
 
-    @differentiable(where Scalar: TensorFlowFloatingPoint)
-    public func subtracting(_ scalar: Scalar) -> Self {
-        self - scalar
+    // @differentiable(where Scalar: TensorFlowFloatingPoint)
+    public func adding(_ scalar: Float) -> Self {
+        self + Scalar(scalar)
     }
 
-    @differentiable(where Scalar: TensorFlowFloatingPoint)
-    public func scaled(by scalar: Scalar) -> Self {
-        self * scalar
+    // @differentiable(where Scalar: TensorFlowFloatingPoint)
+    public func subtracting(_ scalar: Float) -> Self {
+        self - Scalar(scalar)
+    }
+}
+
+extension VectorProtocol {
+    static func + (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+        rhs.adding(lhs)
+    }
+
+    static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+        lhs.adding(rhs)
+    }
+
+    static func - (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+        lhs.subtracting(rhs)
+    }
+
+    static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+        rhs.scaled(by: lhs)
+    }
+
+    static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+        lhs.scaled(by: rhs)
+    }
+}
+
+extension VectorProtocol where VectorSpaceScalar: SignedNumeric {
+    static prefix func - (x: Self) -> Self {
+        .zero - x
+    }
+
+    static func - (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+        (-rhs).adding(lhs)
     }
 }
 
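The operators added above are thin, internal conveniences over the `VectorProtocol` requirements (`adding(_:)`, `subtracting(_:)`, `scaled(by:)`); vector–vector `+`/`-` still comes from `AdditiveArithmetic`. As a rough sketch of how this lets optimizer code be written as plain vector arithmetic, consider a hypothetical helper (not in the diff) living in the same module as the extension, since the operators are not public:

```swift
// Hypothetical helper, same module as the extension above.
// `V` stands in for a model's `TangentVector`.
func sgdStep<V: VectorProtocol & AdditiveArithmetic>(
    velocity: inout V,
    direction: V,
    momentum: Float,
    learningRate: Float
) where V.VectorSpaceScalar == Float {
    // `momentum * velocity` and `direction * learningRate` resolve to
    // `scaled(by:)` through the operators above; `-` between two vectors
    // comes from `AdditiveArithmetic`.
    velocity = momentum * velocity - direction * learningRate
}
```

This is exactly the shape of the new `SGD.update(_:along:)` body in the next file.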

Sources/TensorFlow/Optimizer.swift

Lines changed: 45 additions & 42 deletions
@@ -25,8 +25,7 @@ public protocol Optimizer {
     var learningRate: Scalar { get set }
     /// Updates the specified differentiable variables along the specified
     /// direction.
-    mutating func update(_ variables: inout Model.AllDifferentiableVariables,
-                         along direction: Model.TangentVector)
+    mutating func update(_ variables: inout Model, along direction: Model.TangentVector)
 }
 
 fileprivate extension Tensor where Scalar: Numeric {
@@ -35,14 +34,13 @@ fileprivate extension Tensor where Scalar: Numeric {
     }
 }
 
-// MARK: - Key-path based optimizers
-
 /// Adam optimizer.
 ///
 /// Reference: ["Adam - A Method for Stochastic Optimization"](
 /// https://arxiv.org/abs/1412.6980v8)
 public class Adam<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// A coefficient used to calculate the first and second moments of
@@ -96,7 +94,7 @@ public class Adam<Model: Layer>: Optimizer
         }
     }
 
-
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.AllDifferentiableVariables) {
         step += 1
@@ -127,6 +125,11 @@ public class Adam<Model: Layer>: Optimizer
                 sqrt(secondMoments[keyPath: kp]) + Double(epsilon)
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 /// RMSProp optimizer.
@@ -139,6 +142,7 @@ public class Adam<Model: Layer>: Optimizer
 /// http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 public class RMSProp<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     // TODO: Document `rho`. Keras doesn't document `rho`.
@@ -176,7 +180,7 @@ public class RMSProp<Model: Layer>: Optimizer
         }
     }
 
-
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
    public func update(_ model: inout Model.AllDifferentiableVariables,
                       along direction: Model.TangentVector) {
        step += 1
@@ -195,14 +199,21 @@ public class RMSProp<Model: Layer>: Optimizer
                 (sqrt(alpha[keyPath: kp]) + Double(epsilon))
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 /// Stochastic gradient descent (SGD) optimizer.
 ///
 /// An optimizer that implements stochastic gradient descent, with support for momentum, learning
 /// rate decay, and Nesterov momentum.
-public class SGD<Model: Layer>: Optimizer
-    where Model.AllDifferentiableVariables == Model.TangentVector {
+public class SGD<Model: Differentiable>: Optimizer
    where Model.TangentVector: VectorProtocol & ElementaryFunctions,
          Model.TangentVector.VectorSpaceScalar == Float {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// The momentum factor. It accelerates stochastic gradient descent in the relevant direction
@@ -212,8 +223,8 @@ public class SGD<Model: Layer>: Optimizer
     public var decay: Float
     /// Use Nesterov momentum if true.
     public var nesterov: Bool
-    /// The velocity state of the model
-    public var velocity: Model.AllDifferentiableVariables
+    /// The velocity state of the model.
+    public var velocity: Model.TangentVector = .zero
     /// The set of steps taken.
     public var step: Int = 0
 
@@ -232,53 +243,38 @@ public class SGD<Model: Layer>: Optimizer
         self.momentum = momentum
         self.decay = decay
         self.nesterov = nesterov
-        velocity = model.allDifferentiableVariables
-        for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-            velocity[keyPath: kp].resetToZero()
-        }
-        for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-            velocity[keyPath: kp].resetToZero()
-        }
     }
 
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-            velocity[keyPath: kp] =
-                momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp]
-            if nesterov {
-                model[keyPath: kp] +=
-                    momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp]
-            } else {
-                model[keyPath: kp] += velocity[keyPath: kp]
-            }
-        }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-            velocity[keyPath: kp] =
-                Double(momentum) * velocity[keyPath: kp] -
-                    Double(learningRate) * direction[keyPath: kp]
-            if nesterov {
-                model[keyPath: kp] +=
-                    Double(momentum) * velocity[keyPath: kp] - Double(learningRate) *
-                        direction[keyPath: kp]
-            } else {
-                model[keyPath: kp] += velocity[keyPath: kp]
-            }
+        velocity = momentum * velocity - direction * learningRate
+        if nesterov {
+            model.move(along: momentum * velocity - direction * learningRate)
+        } else {
+            model.move(along: velocity)
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 // MARK: - Manifold optimizers
 
 /// A Riemann manifold stochastic gradient descent (SGD) optimizer.
-public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
-    where Model.TangentVector: VectorProtocol, Model.TangentVector.VectorSpaceScalar == Scalar {
+public class RiemannSGD<Model: Differentiable>: Optimizer
    where Model.TangentVector: VectorProtocol,
          Model.TangentVector.VectorSpaceScalar: FloatingPoint {
+    public typealias Scalar = Model.TangentVector.VectorSpaceScalar
     /// The learning rate.
-    public var learningRate: Scalar
+    public var learningRate: Model.TangentVector.VectorSpaceScalar
 
-    public init(learningRate: Scalar) {
+    public init(learningRate: Model.TangentVector.VectorSpaceScalar) {
         self.learningRate = learningRate
     }
 
@@ -305,6 +301,7 @@ public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
 ///
 public class AdaGrad<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// The smoothing factor (ρ). Typical values are `0.5`, `0.9`, and `0.99`, for smoothing over 2,
@@ -337,6 +334,7 @@ public class AdaGrad<Model: Layer>: Optimizer
         }
     }
 
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
         for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
@@ -351,4 +349,9 @@ public class AdaGrad<Model: Layer>: Optimizer
                 (sqrt(alpha[keyPath: kp] + Double(epsilon)))
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
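With the protocol requirement now taking `inout Model`, an optimizer no longer needs `Model` to be a `Layer` or to expose `AllDifferentiableVariables`. As a rough sketch, assuming only the `Optimizer` requirements visible in the first hunk (`learningRate` and the new `update(_:along:)`), a minimal conforming optimizer could look like the hypothetical type below; it is not part of this commit.

```swift
// Hypothetical optimizer: plain gradient descent written against the new
// requirement. It needs only `Differentiable` plus a `TangentVector` that
// supports `scaled(by:)`.
public class PlainGradientDescent<Model: Differentiable>: Optimizer
    where Model.TangentVector: VectorProtocol,
          Model.TangentVector.VectorSpaceScalar == Float {
    public typealias Model = Model  // mirrors the typealias the optimizers above declare
    public var learningRate: Float

    public init(learningRate: Float = 0.01) {
        self.learningRate = learningRate
    }

    public func update(_ model: inout Model, along direction: Model.TangentVector) {
        // Move the model against the gradient direction.
        model.move(along: direction.scaled(by: -learningRate))
    }
}
```

This is the shape `SGD` now takes, while the key-path-based `Adam`, `RMSProp`, and `AdaGrad` keep their old bodies and simply forward the new entry point to them.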

Tests/TensorFlowTests/SequentialTests.swift

Lines changed: 13 additions & 3 deletions
@@ -29,7 +29,10 @@ final class SequentialTests: XCTestCase {
             }
         }
         var model = Model()
-        let optimizer = SGD(for: model, learningRate: 0.02)
+        let sgd = SGD(for: model, learningRate: 0.02)
+        let rmsprop = RMSProp(for: model, learningRate: 0.02)
+        let adam = Adam(for: model, learningRate: 0.02)
+        let adagrad = AdaGrad(for: model, learningRate: 0.02)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [0, 1, 1, 0]
         Context.local.learningPhase = .training
@@ -38,10 +41,17 @@ final class SequentialTests: XCTestCase {
                 let ŷ = model(x)
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
-            optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
+            sgd.update(&model, along: 𝛁model)
+            sgd.update(&model.allDifferentiableVariables, along: 𝛁model)
+            rmsprop.update(&model, along: 𝛁model)
+            rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model)
+            adam.update(&model, along: 𝛁model)
+            adam.update(&model.allDifferentiableVariables, along: 𝛁model)
+            adagrad.update(&model, along: 𝛁model)
+            adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
         }
         XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
-                       [[ 0.4904838], [0.49942452], [0.49740878], [ 0.5106092]])
+                       [[0.47705528], [0.47705528], [0.47705528], [0.47705528]])
     }
 
     static var allTests = [

Tests/TensorFlowTests/TrivialModelTests.swift

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ final class TrivialModelTests: XCTestCase {
                 let ŷ = classifier(x)
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
-            optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
+            optimizer.update(&classifier, along: 𝛁model)
         }
         let ŷ = classifier.inferring(from: x)
         XCTAssertEqual(round(ŷ), y)
