
Commit a41903b

lakshya-sky authored and rxwei committed
Add AdaDelta optimizer (#302)
1 parent ec83fa7 commit a41903b

File tree

    Sources/TensorFlow/Optimizer.swift
    Tests/TensorFlowTests/SequentialTests.swift

2 files changed, +124 -29 lines changed

Sources/TensorFlow/Optimizer.swift

Lines changed: 120 additions & 28 deletions
@@ -264,41 +264,14 @@ public class SGD<Model: Differentiable>: Optimizer
     }
 }
 
-// MARK: - Manifold optimizers
-
-/// A Riemann manifold stochastic gradient descent (SGD) optimizer.
-public class RiemannSGD<Model: Differentiable>: Optimizer
-    where Model.TangentVector: VectorProtocol,
-          Model.TangentVector.VectorSpaceScalar: FloatingPoint {
-    public typealias Scalar = Model.TangentVector.VectorSpaceScalar
-    /// The learning rate.
-    public var learningRate: Model.TangentVector.VectorSpaceScalar
-
-    public init(learningRate: Model.TangentVector.VectorSpaceScalar) {
-        self.learningRate = learningRate
-    }
-
-    public convenience init(
-        for _: __shared Model,
-        learningRate: Scalar
-    ) {
-        self.init(learningRate: learningRate)
-    }
-
-    public func update(_ model: inout Model.AllDifferentiableVariables,
-                       along direction: Model.TangentVector) {
-        model.move(along: (.zero - direction).scaled(by: learningRate))
-    }
-}
 
 /// AdaGrad optimizer.
 ///
 /// Individually adapts the learning rates of all model parameters by scaling them inversely proportional to
 /// the square root of the sum of all the historical squared values of the gradient.
 ///
 /// Reference: ["Adaptive Subgradient Methods for Online Learning and Stochastic Optimization"](
-/// http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
-///
+/// http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
 public class AdaGrad<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
     public typealias Model = Model
@@ -355,3 +328,122 @@ public class AdaGrad<Model: Layer>: Optimizer
         update(&model.allDifferentiableVariables, along: direction)
     }
 }
+
+/// ADADELTA optimizer.
+///
+/// ADADELTA is a more robust extension of AdaGrad. ADADELTA adapts learning rates based on a moving
+/// window of gradient updates rather than accumulating all past gradients. ADADELTA can continue to
+/// learn even after many update steps.
+///
+/// Reference: ["ADADELTA: An Adaptive Learning Rate Method"](https://arxiv.org/abs/1212.5701)
+public class AdaDelta<Model: Layer>: Optimizer
+    where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
+    /// The learning rate.
+    public var learningRate: Float
+    /// The decay factor, corresponding to fraction of gradient to keep at each time step.
+    public var rho: Float
+    /// A small scalar added to the denominator to improve numerical stability.
+    public var epsilon: Float
+    /// The learning rate decay.
+    public var decay: Float
+    /// The current step.
+    public var step: Int = 0
+    /// The accumulated, exponentially decaying average of squared gradients.
+    public var averageSquared: Model.TangentVector
+    /// The accumulated parameter updates.
+    public var accumulatedDelta: Model.TangentVector
+
+    public init(
+        for model: __shared Model,
+        learningRate: Float = 1,
+        rho: Float = 0.95,
+        epsilon: Float = 1e-6,
+        decay: Float = 0
+    ) {
+        precondition(learningRate >= 0, "Learning rate must be non-negative")
+        precondition(0 <= rho && rho <= 1, "Rho parameter must be between 0 and 1")
+        precondition(0 <= epsilon, "Epsilon parameter must be non-negative")
+        precondition(decay >= 0, "Learning rate decay must be non-negative")
+
+        self.learningRate = learningRate
+        self.rho = rho
+        self.epsilon = epsilon
+        self.decay = decay
+
+        averageSquared = model.allDifferentiableVariables
+        accumulatedDelta = model.allDifferentiableVariables
+
+        for kp in averageSquared.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+            averageSquared[keyPath: kp].resetToZero()
+            accumulatedDelta[keyPath: kp].resetToZero()
+        }
+        for kp in averageSquared.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+            averageSquared[keyPath: kp].resetToZero()
+            accumulatedDelta[keyPath: kp].resetToZero()
+        }
+    }
+
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
+    public func update(_ model: inout Model.AllDifferentiableVariables,
+                       along direction: Model.AllDifferentiableVariables) {
+        step += 1
+        let learningRate = self.learningRate / (1 + decay * Float(step))
+
+        // Update `Tensor<Float>` and `Tensor<Double>` variables.
+        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+            averageSquared[keyPath: kp] *= rho
+            averageSquared[keyPath: kp] +=
+                (1 - rho) * (direction[keyPath: kp] * direction[keyPath: kp])
+            var stepSize = direction[keyPath: kp] *
+                sqrt(accumulatedDelta[keyPath: kp] + epsilon)
+            stepSize /= sqrt(averageSquared[keyPath: kp] + epsilon)
+            model[keyPath: kp] -= learningRate * stepSize
+            accumulatedDelta[keyPath: kp] *= rho
+            accumulatedDelta[keyPath: kp] += (1 - rho) * stepSize.squared()
+        }
+        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+            averageSquared[keyPath: kp] *= Double(rho)
+            averageSquared[keyPath: kp] +=
+                (1 - Double(rho)) * (direction[keyPath: kp] * direction[keyPath: kp])
+            var stepSize = direction[keyPath: kp] *
+                sqrt(accumulatedDelta[keyPath: kp] + Double(epsilon))
+            stepSize /= sqrt(averageSquared[keyPath: kp] + Double(epsilon))
+            model[keyPath: kp] -= Double(learningRate) * stepSize
+            accumulatedDelta[keyPath: kp] *= Double(rho)
+            accumulatedDelta[keyPath: kp] += (1 - Double(rho)) * stepSize.squared()
+        }
+    }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
+}
+
+// MARK: - Manifold optimizers
+
+/// A Riemann manifold stochastic gradient descent (SGD) optimizer.
+public class RiemannSGD<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol,
+          Model.TangentVector.VectorSpaceScalar: FloatingPoint {
+    public typealias Scalar = Model.TangentVector.VectorSpaceScalar
+    /// The learning rate.
+    public var learningRate: Model.TangentVector.VectorSpaceScalar
+
+    public init(learningRate: Model.TangentVector.VectorSpaceScalar) {
+        self.learningRate = learningRate
+    }
+
+    public convenience init(
+        for _: __shared Model,
+        learningRate: Scalar
+    ) {
+        self.init(learningRate: learningRate)
+    }
+
+    public func update(_ model: inout Model.AllDifferentiableVariables,
+                       along direction: Model.TangentVector) {
+        model.move(along: (.zero - direction).scaled(by: learningRate))
+    }
+}
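
The key-path loops in the new update(_:along:) method apply the same ADADELTA rule independently to every Tensor<Float> and Tensor<Double> property of the model. As a minimal standalone sketch of that rule (not part of this commit; the adaDeltaStep function, its signature, and the sample values are illustrative only, with defaults matching the new initializer), the per-element arithmetic on a plain [Float] vector looks like this:

// Standalone sketch: one ADADELTA step on a plain [Float] parameter vector,
// mirroring what the key-path loops above do for each tensor.
func adaDeltaStep(
    parameters: inout [Float],
    gradient: [Float],
    averageSquared: inout [Float],    // decaying average of squared gradients
    accumulatedDelta: inout [Float],  // decaying average of squared updates
    learningRate: Float = 1,
    rho: Float = 0.95,
    epsilon: Float = 1e-6
) {
    for i in parameters.indices {
        // Decay the squared-gradient average and mix in the new gradient.
        averageSquared[i] = rho * averageSquared[i] + (1 - rho) * gradient[i] * gradient[i]
        // Scale the gradient by RMS(previous updates) / RMS(gradients).
        let stepSize = gradient[i]
            * (accumulatedDelta[i] + epsilon).squareRoot()
            / (averageSquared[i] + epsilon).squareRoot()
        parameters[i] -= learningRate * stepSize
        // Track the squared update for the next step's numerator.
        accumulatedDelta[i] = rho * accumulatedDelta[i] + (1 - rho) * stepSize * stepSize
    }
}

// Example call with made-up values.
var params: [Float] = [0.5, -0.3]
var avgSq = [Float](repeating: 0, count: 2)
var accDelta = [Float](repeating: 0, count: 2)
adaDeltaStep(parameters: &params, gradient: [0.1, -0.2],
             averageSquared: &avgSq, accumulatedDelta: &accDelta)

Each step scales the raw gradient by the ratio of the two running RMS estimates, so the effective step size adapts per parameter rather than relying on a hand-tuned global learning rate.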

Tests/TensorFlowTests/SequentialTests.swift

Lines changed: 4 additions & 1 deletion
@@ -33,6 +33,7 @@ final class SequentialTests: XCTestCase {
         let rmsprop = RMSProp(for: model, learningRate: 0.02)
         let adam = Adam(for: model, learningRate: 0.02)
         let adagrad = AdaGrad(for: model, learningRate: 0.02)
+        let adadelta = AdaDelta(for: model, learningRate: 0.02)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [0, 1, 1, 0]
         Context.local.learningPhase = .training
@@ -49,9 +50,11 @@ final class SequentialTests: XCTestCase {
             adam.update(&model.allDifferentiableVariables, along: 𝛁model)
             adagrad.update(&model, along: 𝛁model)
             adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
+            adadelta.update(&model, along: 𝛁model)
+            adadelta.update(&model.allDifferentiableVariables, along: 𝛁model)
         }
         XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
-                       [[0.47705528], [0.47705528], [0.47705528], [0.47705528]])
+                       [[0.47683996], [0.47683996], [0.47683996], [0.47683996]])
     }
 
     static var allTests = [

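The ADADELTA doc comment added above says the optimizer "can continue to learn even after many update steps"; the difference from AdaGrad is the gradient accumulator in the denominator of the step size. AdaGrad divides by the square root of a sum over all past squared gradients, which only grows, while ADADELTA divides by a decaying average that stays bounded. A standalone scalar sketch of that contrast (plain Swift, not from this repository; the variable names are illustrative):

// Compare the gradient-accumulation denominators of AdaGrad and ADADELTA
// for a constant gradient of 1.0.
let rho: Float = 0.95
let epsilon: Float = 1e-6
var adaGradSum: Float = 0        // AdaGrad: sum of all squared gradients, grows without bound
var adaDeltaAverage: Float = 0   // ADADELTA: decaying average, approaches 1.0

for step in 1...10_000 {
    adaGradSum += 1.0
    adaDeltaAverage = rho * adaDeltaAverage + (1 - rho)
    if step % 2_000 == 0 {
        let adaGradScale = 1.0 / (adaGradSum + epsilon).squareRoot()
        let adaDeltaScale = 1.0 / (adaDeltaAverage + epsilon).squareRoot()
        // AdaGrad's scale keeps shrinking toward zero; ADADELTA's settles near 1.
        print(step, adaGradScale, adaDeltaScale)
    }
}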