Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 268055e

Browse files
authored
Reverting inference behavior for BatchNorm. (#426)
1 parent 43cfdd6 commit 268055e

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

Sources/TensorFlow/Layers/Normalization.swift

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,8 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
8282
let inv = rsqrt(moments.variance + epsilon) * scale.reshaped(to: moments.variance.shape)
8383
return (input - moments.mean) * inv + offset.reshaped(to: moments.mean.shape)
8484
case .inference:
85-
let scaleShape = runningVariance.value.shape
86-
let offsetShape = runningMean.value.shape
87-
let inv = rsqrt(runningVariance.value + epsilon) * scale.reshaped(to: scaleShape)
88-
return (input - runningMean.value) * inv + offset.reshaped(to: offsetShape)
85+
let inv = rsqrt(runningVariance.value + epsilon) * scale
86+
return (input - runningMean.value) * inv + offset
8987
}
9088
}
9189

0 commit comments

Comments
 (0)