This repository was archived by the owner on Jul 1, 2023. It is now read-only.
2 files changed: 10 additions and 6 deletions.
Columns: original file line number, diff line number, diff line change.
@@ -47,14 +47,16 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
47
47
let squaredDiff : Tensor = Raw . squaredDifference ( self , mean)
48
48
let variance = squaredDiff. mean ( alongAxes: axis)
49
49
50
- let diff = self - mean
50
+ let diff = self - mean
51
51
let inv = rsqrt ( variance + epsilon)
52
52
let norm = diff * inv
53
53
54
- let dNorm = v * scale
54
+ let dNorm = v * scale
55
55
let dVariance = - ( dNorm * diff) . sum ( alongAxes: axis) / 2 * pow( inv, - 3 )
56
- let dMean = ( - dNorm * inv) . sum ( alongAxes: axis) +
57
- dVariance * ( - diff * 2 ) . mean ( alongAxes: axis)
56
+ // Note: `dMean` is split into two lines to avoid the "compiler is unable to type-check
57
+ // this expression in reasonable time" error.
58
+ var dMean = ( - dNorm * inv) . sum ( alongAxes: axis)
59
+ dMean = dMean + dVariance * ( - diff * 2 ) . mean ( alongAxes: axis)
58
60
let dOffset = v. sum ( alongAxes: axis)
59
61
let dScale = ( norm * v) . sum ( alongAxes: axis)
60
62
let dim = Tensor ( Tensor < Int32 > ( self . shapeTensor [ axis] ) )
Columns: original file line number, diff line number, diff line change.
@@ -105,8 +105,10 @@ public class Adam<Model: Layer>: Optimizer
105
105
along direction: Model . AllDifferentiableVariables ) {
106
106
step += 1
107
107
let learningRate = self . learningRate * 1 / ( 1 + decay * Float( step) )
108
- let stepSize = learningRate * ( sqrt ( 1 - pow( beta2, Float ( step) ) ) /
109
- ( 1 - pow( beta1, Float ( step) ) ) )
108
+ // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check
109
+ // this expression in reasonable time" error.
110
+ var stepSize = learningRate * sqrt( 1 - pow( beta2, Float ( step) ) )
111
+ stepSize = stepSize / ( 1 - pow( beta1, Float ( step) ) )
110
112
// Update Float & Double Tensor variables.
111
113
for kp in model. recursivelyAllWritableKeyPaths ( to: Tensor< Float> . self ) {
112
114
firstMoments [ keyPath: kp] =
You can’t perform that action at this time.
0 commit comments