
Commit 085217d

Deprecate Differentiable.AllDifferentiableVariables. (tensorflow#419)
Remove usages of `AllDifferentiableVariables` and `var allDifferentiableVariables`. Break up an expression into two subexpressions to fix the "compiler is unable to type-check this expression in reasonable time" error.
1 parent faf540a commit 085217d

10 files changed: 19 additions (+19), 91 deletions (−91)

Sources/TensorFlow/Core/DataTypes.swift

Lines changed: 1 addition & 2 deletions
@@ -84,8 +84,7 @@ extension Int64: TensorFlowIndex {}
 public protocol TensorFlowFloatingPoint:
     TensorFlowScalar & BinaryFloatingPoint & Differentiable & ElementaryFunctions
     where Self.RawSignificand: FixedWidthInteger,
-          Self == Self.TangentVector,
-          Self == Self.AllDifferentiableVariables {}
+          Self == Self.TangentVector {}
 
 extension Float: TensorFlowFloatingPoint {}
 extension Double: TensorFlowFloatingPoint {}
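With the `AllDifferentiableVariables` constraint gone, `TensorFlowFloatingPoint` only requires `Self == Self.TangentVector`, so a scalar's gradient has the scalar's own type. A minimal sketch (the helper below is hypothetical, not part of this commit):

```swift
// Because `Self == Self.TangentVector`, `gradient(at:in:)` on a scalar of
// type `T` returns another `T`, with no associated-type conversions.
func squareGradient<T: TensorFlowFloatingPoint>(_ x: T) -> T {
    gradient(at: x) { $0 * $0 }
}

// squareGradient(Float(3)) == 6
```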

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 0 additions & 1 deletion
@@ -578,5 +578,4 @@ extension Tensor: PointwiseMultiplicative where Scalar: Numeric {
 
 extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
     public typealias TangentVector = Tensor
-    public typealias AllDifferentiableVariables = Tensor
 }
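Dropping the typealias changes nothing observable: `Tensor` keeps `TangentVector = Tensor`, so gradients of tensor computations remain plain tensors. A small illustrative example (assumed usage, not from the diff):

```swift
let x = Tensor<Float>([1, 2, 3])
// The closure returns a scalar `Tensor<Float>`; the gradient has the same
// type as `x` because `Tensor.TangentVector == Tensor`.
let g = gradient(at: x) { ($0 * $0).sum() }  // Tensor<Float>([2, 4, 6])
```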

Sources/TensorFlow/Layer.swift

Lines changed: 3 additions & 10 deletions
@@ -14,8 +14,7 @@
 public protocol Module: Differentiable, KeyPathIterable
     where TangentVector: VectorProtocol & ElementaryFunctions &
-                         PointwiseMultiplicative & KeyPathIterable,
-          AllDifferentiableVariables == TangentVector {
+                         PointwiseMultiplicative & KeyPathIterable {
     /// The input type of the layer.
     associatedtype Input
     /// The output type of the layer.
@@ -55,7 +54,6 @@ public extension Layer {
 /// An empty struct representing empty `TangentVector`s for parameterless layers.
 public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunctions,
                                   PointwiseMultiplicative, KeyPathIterable {
-    public typealias AllDifferentiableVariables = EmptyTangentVector
     public typealias VectorSpaceScalar = Float
 
     public func adding(_ x: Float) -> EmptyTangentVector { self }
@@ -69,17 +67,12 @@ public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunc
 /// A parameterless neural network layer.
 ///
 /// The `TangentVector` of parameterless layers is always `EmptyTangentVector`.
-public protocol ParameterlessLayer: Layer where AllDifferentiableVariables == EmptyTangentVector {
+public protocol ParameterlessLayer: Layer {
     @differentiable
     func callAsFunction(_ input: Input) -> Output
 }
 
 public extension ParameterlessLayer {
-    var allDifferentiableVariables: EmptyTangentVector {
-        get { EmptyTangentVector() }
-        set {}
-    }
-
     mutating func move(along direction: EmptyTangentVector) {}
 }
 
@@ -98,7 +91,7 @@ public extension Layer {
     @usableFromInline
     internal func _vjpInferring(from input: Input)
         -> (value: Output, pullback: (Output.TangentVector)
-            -> (AllDifferentiableVariables, Input.TangentVector)) {
+            -> (TangentVector, Input.TangentVector)) {
         withLearningPhase(LearningPhase.inference) {
             let (output, pullback) = appliedForBackpropagation(to: input)
             return (output, { v in pullback(v) })
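After this change, conforming to `ParameterlessLayer` no longer involves `AllDifferentiableVariables` or an `allDifferentiableVariables` accessor; a differentiable `callAsFunction` and an `EmptyTangentVector` tangent are enough. A hedged sketch (the `Scale2x` layer is hypothetical):

```swift
public struct Scale2x<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
    // An empty tangent: this layer has no trainable parameters.
    public typealias TangentVector = EmptyTangentVector

    public init() {}

    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        input * 2
    }
}
```

Because `move(along:)` for `EmptyTangentVector` is a no-op (see the extension above), optimizers can update such layers uniformly without touching any stored state.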

Sources/TensorFlow/Layers/Upsampling.swift

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
 
     private func _vjpRepeatingElements(
         _ input: Tensor<Scalar>, alongAxis axis: Int, count: Int
-    ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (AllDifferentiableVariables, Tensor<Scalar>)) {
+    ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
         let value = repeatingElements(input, alongAxis: axis, count: count)
         return (value, { v in
             let splits = Raw.split(

Sources/TensorFlow/Loss.swift

Lines changed: 4 additions & 1 deletion
@@ -287,5 +287,8 @@ public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
 ) -> Tensor<Scalar> {
     // This numerically stable implementation is based on the TensorFlow Python API.
     let maxLogitsWithZero = max(logits, Tensor(0))
-    return reduction(maxLogitsWithZero - logits * labels + log(1 + exp(-abs(logits))))
+    // Note: `result` is split into two lines to avoid the "compiler is unable to type-check this
+    // expression in reasonable time" error.
+    let result = log(1 + exp(-abs(logits)))
+    return reduction(maxLogitsWithZero - logits * labels + result)
 }
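For reference (not part of the commit), the split expression still computes the standard numerically stable form of sigmoid cross-entropy. Writing \(\sigma(x) = 1/(1 + e^{-x})\) for logits \(x\) and labels \(z\):

```latex
\max(x, 0) - xz + \log\bigl(1 + e^{-|x|}\bigr)
    = -z\,\log\sigma(x) - (1 - z)\,\log\bigl(1 - \sigma(x)\bigr)
```

The `max`/`abs` form keeps the argument of `exp` nonpositive, so the computation cannot overflow for large-magnitude logits.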

Sources/TensorFlow/Optimizers/MomentumBased.swift

Lines changed: 9 additions & 59 deletions
@@ -55,14 +55,6 @@ public class RMSProp<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
         alpha = alpha * rho + direction .* direction * (1 - rho)
@@ -107,14 +99,6 @@ public class AdaGrad<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         alpha = rho + direction .* direction
         let denominator = Model.TangentVector.sqrt(alpha) + epsilon
         model.move(along: -learningRate * direction ./ denominator)
@@ -166,14 +150,6 @@ public class AdaDelta<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate / (1 + decay * Float(step))
         averageSquared = rho * averageSquared + (1 - rho) * direction .* direction
@@ -230,15 +206,7 @@ public class Adam<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         // Note: `stepSize` and `secondMoments` are split into two lines to avoid the "compiler is
@@ -262,8 +230,7 @@ public class Adam<Model: Differentiable>: Optimizer
 public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
     where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
                                ElementaryFunctions & KeyPathIterable,
-          Model.TangentVector.VectorSpaceScalar == Float,
-          Model.AllDifferentiableVariables == Model.TangentVector {
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
@@ -304,15 +271,7 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check
@@ -323,11 +282,11 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
 
         // Update `infinityNorm` using a key path approach because `max(_:_:)` cannot be
         // currently applied in a simpler manner.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+        for kp in infinityNorm.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
             infinityNorm[keyPath: kp] = max(
                 beta2 * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
         }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+        for kp in infinityNorm.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
             infinityNorm[keyPath: kp] = max(
                 Double(beta2) * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
         }
@@ -347,8 +306,7 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
 public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
     where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
                                ElementaryFunctions & KeyPathIterable,
-          Model.TangentVector.VectorSpaceScalar == Float,
-          Model.AllDifferentiableVariables == Model.TangentVector {
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
@@ -390,15 +348,7 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let beta1Power = pow(beta1, step)
         let beta2Power = pow(beta2, step)
@@ -413,11 +363,11 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
 
         // Update `secondMomentsMax` using a key path approach because `max(_:_:)` cannot be
         // currently applied in a simpler manner.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+        for kp in secondMomentsMax.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
             secondMomentsMax[keyPath: kp] = max(
                 secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
         }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+        for kp in secondMomentsMax.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
             secondMomentsMax[keyPath: kp] = max(
                 secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
         }
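With the deprecated overloads removed, each optimizer has a single `update(_:along:)` entry point that mutates the model directly through `move(along:)`. A minimal usage sketch (the model shape, data, and hyperparameters are illustrative only):

```swift
var model = Dense<Float>(inputSize: 2, outputSize: 1)
let optimizer = Adam(for: model, learningRate: 1e-2)

let x = Tensor<Float>([[0, 0], [0, 1], [1, 0], [1, 1]])
let y = Tensor<Float>([[0], [1], [1], [0]])

for _ in 0..<100 {
    // The gradient is a `Model.TangentVector`, which `update(_:along:)`
    // now consumes directly; no `allDifferentiableVariables` round trip.
    let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
        meanSquaredError(predicted: model(x), expected: y)
    }
    optimizer.update(&model, along: 𝛁model)
}
```

The updated tests in `OptimizersTests.swift` and `SequentialTests.swift` below follow the same pattern.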

Sources/TensorFlow/Optimizers/SGD.swift

Lines changed: 0 additions & 8 deletions
@@ -52,14 +52,6 @@ public class SGD<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
         velocity = momentum * velocity - direction * learningRate

Sources/third_party/Experimental/Complex.swift

Lines changed: 0 additions & 1 deletion
@@ -56,7 +56,6 @@ struct Complex<T: FloatingPoint> {
 
 extension Complex: Differentiable where T: Differentiable {
     typealias TangentVector = Complex
-    typealias AllDifferentiableVariables = Complex
 }
 
 extension Complex {

Tests/TensorFlowTests/OptimizersTests.swift

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ final class OptimizerTests: XCTestCase {
             let ŷ = classifier(x)
             return meanSquaredError(predicted: ŷ, expected: y)
         }
-        optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
+        optimizer.update(&classifier, along: 𝛁model)
     }
 
     // trained classifier should return valid values

Tests/TensorFlowTests/SequentialTests.swift

Lines changed: 0 additions & 7 deletions
@@ -52,19 +52,12 @@ final class SequentialTests: XCTestCase {
             return meanSquaredError(predicted: ŷ, expected: y)
         }
         sgd.update(&model, along: 𝛁model)
-        sgd.update(&model.allDifferentiableVariables, along: 𝛁model)
         rmsprop.update(&model, along: 𝛁model)
-        rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model)
         adam.update(&model, along: 𝛁model)
-        adam.update(&model.allDifferentiableVariables, along: 𝛁model)
         adamax.update(&model, along: 𝛁model)
-        adamax.update(&model.allDifferentiableVariables, along: 𝛁model)
         amsgrad.update(&model, along: 𝛁model)
-        amsgrad.update(&model.allDifferentiableVariables, along: 𝛁model)
         adagrad.update(&model, along: 𝛁model)
-        adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
         adadelta.update(&model, along: 𝛁model)
-        adadelta.update(&model.allDifferentiableVariables, along: 𝛁model)
     }
 }
 XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
