Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 79bcb7c

Browse files
committed
Change BatchNorm momentum and epsilon to scalars.
Change `BatchNorm.momentum` and `BatchNorm.epsilon` to `Scalar` instead of `Tensor<Scalar>`. Semantically, these properties are scalars. It is not clear why `BatchNorm` originally defined these properties as `Tensor<Scalar>` since the beginning: 6ca3813
1 parent 6309c55 commit 79bcb7c

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

Sources/TensorFlow/Layers/Normalization.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
2525
/// The feature dimension.
2626
@noDerivative public let axis: Int
2727
/// The momentum for the running mean and running variance.
28-
@noDerivative public let momentum: Tensor<Scalar>
28+
@noDerivative public let momentum: Scalar
2929
/// The offset value, also known as beta.
3030
public var offset: Tensor<Scalar>
3131
/// The scale value, also known as gamma.
3232
public var scale: Tensor<Scalar>
3333
/// The variance epsilon value.
34-
@noDerivative public let epsilon: Tensor<Scalar>
34+
@noDerivative public let epsilon: Scalar
3535
/// The running mean.
3636
@noDerivative public let runningMean: Parameter<Scalar>
3737
/// The running variance.
@@ -49,10 +49,10 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
4949
/// - runningVariance: The running variance.
5050
public init(
5151
axis: Int,
52-
momentum: Tensor<Scalar>,
52+
momentum: Scalar,
5353
offset: Tensor<Scalar>,
5454
scale: Tensor<Scalar>,
55-
epsilon: Tensor<Scalar>,
55+
epsilon: Scalar,
5656
runningMean: Tensor<Scalar>,
5757
runningVariance: Tensor<Scalar>
5858
) {
@@ -105,8 +105,8 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
105105
public init(
106106
featureCount: Int,
107107
axis: Int = -1,
108-
momentum: Tensor<Scalar> = Tensor(0.99),
109-
epsilon: Tensor<Scalar> = Tensor(0.001)
108+
momentum: Scalar = 0.99,
109+
epsilon: Scalar = 0.001
110110
) {
111111
self.init(
112112
axis: axis,

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,15 +1222,12 @@ final class LayerTests: XCTestCase {
12221222
Context.local.learningPhase = .inference
12231223
// This tests for a specific failure that had impacted the MiniGo model.
12241224
let miniGoTensor = Tensor<Float>(randomUniform: [2, 19, 19, 256])
1225-
let miniGoBatchNorm = BatchNorm(
1226-
featureCount: 256,
1227-
momentum: Tensor<Float>(0.95),
1228-
epsilon: Tensor<Float>(1e-5))
1225+
let miniGoBatchNorm = BatchNorm<Float>(featureCount: 256, momentum: 0.95, epsilon: 1e-5)
12291226
let miniGoResult = miniGoBatchNorm(miniGoTensor)
12301227
XCTAssertEqual(miniGoTensor.shape, miniGoResult.shape)
12311228

12321229
let x = Tensor<Float>(rangeFrom: 0, to: 20, stride: 1).reshaped(to: [4,5])
1233-
let epsilon = Tensor<Float>(0.001)
1230+
let epsilon: Float = 0.001
12341231
let bnLayer = BatchNorm<Float>(featureCount: 5, axis: 1, epsilon: epsilon)
12351232
// Test inference before any training.
12361233
assertEqual(bnLayer.inferring(from: x), x / TensorFlow.sqrt(1 + epsilon), accuracy: 1e-5)

0 commit comments

Comments
 (0)