Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b7ba0d5

Browse files
dan-zheng authored and saeta committed
Change momentum and epsilon properties to scalars. (#525)
Change the following properties from `Tensor<Scalar>` to `Scalar`. - `BatchNorm.momentum` - `BatchNorm.epsilon` - `LayerNorm.epsilon` Semantically, these properties are always scalars. Note: this will be an API-breaking change in Swift for TensorFlow 0.6. Deprecating the other `BatchNorm` and `LayerNorm` initializers is tricky because it causes ambiguity problems.
1 parent 27c2f21 commit b7ba0d5

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

Sources/TensorFlow/Layers/Normalization.swift

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
2525
/// The feature dimension.
2626
@noDerivative public let axis: Int
2727
/// The momentum for the running mean and running variance.
28-
@noDerivative public let momentum: Tensor<Scalar>
28+
@noDerivative public let momentum: Scalar
2929
/// The offset value, also known as beta.
3030
public var offset: Tensor<Scalar>
3131
/// The scale value, also known as gamma.
3232
public var scale: Tensor<Scalar>
3333
/// The variance epsilon value.
34-
@noDerivative public let epsilon: Tensor<Scalar>
34+
@noDerivative public let epsilon: Scalar
3535
/// The running mean.
3636
@noDerivative public let runningMean: Parameter<Scalar>
3737
/// The running variance.
@@ -49,10 +49,10 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
4949
/// - runningVariance: The running variance.
5050
public init(
5151
axis: Int,
52-
momentum: Tensor<Scalar>,
52+
momentum: Scalar,
5353
offset: Tensor<Scalar>,
5454
scale: Tensor<Scalar>,
55-
epsilon: Tensor<Scalar>,
55+
epsilon: Scalar,
5656
runningMean: Tensor<Scalar>,
5757
runningVariance: Tensor<Scalar>
5858
) {
@@ -105,8 +105,8 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
105105
public init(
106106
featureCount: Int,
107107
axis: Int = -1,
108-
momentum: Tensor<Scalar> = Tensor(0.99),
109-
epsilon: Tensor<Scalar> = Tensor(0.001)
108+
momentum: Scalar = 0.99,
109+
epsilon: Scalar = 0.001
110110
) {
111111
self.init(
112112
axis: axis,
@@ -131,14 +131,14 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
131131
/// The axis.
132132
@noDerivative public let axis: Int
133133
/// The variance epsilon value.
134-
@noDerivative public let epsilon: Tensor<Scalar>
134+
@noDerivative public let epsilon: Scalar
135135

136136
/// Creates a layer normalization layer.
137137
public init(
138138
offset: Tensor<Scalar>,
139139
scale: Tensor<Scalar>,
140140
axis: Int,
141-
epsilon: Tensor<Scalar>
141+
epsilon: Scalar
142142
) {
143143
self.offset = offset
144144
self.scale = scale
@@ -155,7 +155,7 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
155155
public init(
156156
featureCount: Int,
157157
axis: Int,
158-
epsilon: Tensor<Scalar> = Tensor(0.001)
158+
epsilon: Scalar = 0.001
159159
) {
160160
self.init(
161161
offset: Tensor(zeros: [featureCount]),

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,15 +1222,12 @@ final class LayerTests: XCTestCase {
12221222
Context.local.learningPhase = .inference
12231223
// This tests for a specific failure that had impacted the MiniGo model.
12241224
let miniGoTensor = Tensor<Float>(randomUniform: [2, 19, 19, 256])
1225-
let miniGoBatchNorm = BatchNorm(
1226-
featureCount: 256,
1227-
momentum: Tensor<Float>(0.95),
1228-
epsilon: Tensor<Float>(1e-5))
1225+
let miniGoBatchNorm = BatchNorm<Float>(featureCount: 256, momentum: 0.95, epsilon: 1e-5)
12291226
let miniGoResult = miniGoBatchNorm(miniGoTensor)
12301227
XCTAssertEqual(miniGoTensor.shape, miniGoResult.shape)
12311228

12321229
let x = Tensor<Float>(rangeFrom: 0, to: 20, stride: 1).reshaped(to: [4,5])
1233-
let epsilon = Tensor<Float>(0.001)
1230+
let epsilon: Float = 0.001
12341231
let bnLayer = BatchNorm<Float>(featureCount: 5, axis: 1, epsilon: epsilon)
12351232
// Test inference before any training.
12361233
assertEqual(bnLayer.inferring(from: x), x / TensorFlow.sqrt(1 + epsilon), accuracy: 1e-5)
@@ -1307,10 +1304,7 @@ final class LayerTests: XCTestCase {
13071304
Context.local.learningPhase = .inference
13081305
// This tests for a specific failure that had impacted the Transformer model.
13091306
let transformerTensor = Tensor<Float>(randomUniform: [1, 1, 768])
1310-
let transformerLayerNorm = LayerNorm(
1311-
featureCount: 768,
1312-
axis: -1,
1313-
epsilon: Tensor<Float>(1e-5))
1307+
let transformerLayerNorm = LayerNorm<Float>(featureCount: 768, axis: -1, epsilon: 1e-5)
13141308
let transformerResult = transformerLayerNorm(transformerTensor)
13151309
XCTAssertEqual(transformerTensor.shape, transformerResult.shape)
13161310
}

0 commit comments

Comments
 (0)