Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b5a49b7

Browse files
jon-tow authored and marcrasi committed
[Initializers] Refactor random and variance-scaling initializers (#335)
1 parent fb164d1 commit b5a49b7

File tree

4 files changed

+160
-62
lines changed

4 files changed

+160
-62
lines changed

Sources/TensorFlow/Initializers.swift

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -453,63 +453,75 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
453453
}
454454
}
455455

456-
// TODO: Can become fileprivate after the 0.4 release.
457-
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
458-
static func glorot(
459-
fromStandardUniform randomUniform: __shared Tensor<Scalar>,
460-
shape: __shared TensorShape
461-
) -> Tensor<Scalar> {
462-
let spatialDimCount = shape.count - 2
463-
let receptiveField = shape[0..<spatialDimCount].contiguousSize
464-
let fanIn = shape[shape.count - 2] * receptiveField
465-
let fanOut = shape[shape.count - 1] * receptiveField
466-
let minusOneToOne = 2 * randomUniform - 1
467-
return Scalar.sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
456+
//===------------------------------------------------------------------------------------------===//
457+
// Variance Scaling
458+
//===------------------------------------------------------------------------------------------===//
459+
460+
fileprivate extension TensorShape {
461+
// Returns the `fanIn` and `fanOut` counts for `TensorShape`s where the last two axes represent
462+
// the input channel count and output channel count, respectively.
463+
func fans() -> (in: Int, out: Int) {
464+
precondition(
465+
count > 1,
466+
"Fans cannot be computed for tensors with fewer than 2 dimensions. Got: \(count)")
467+
468+
// Fans for a 2-D tensor, e.g. `Dense`/`Embedding` weights.
469+
if count == 2 {
470+
return (self[0], self[1])
471+
}
472+
// Fans for tensors with rank greater than `2`, specifically convolution filters.
473+
let lastSpatialAxis = endIndex - 3
474+
let spatialSize = self[0..<(lastSpatialAxis + 1)].contiguousSize
475+
let inputAxis = endIndex - 2
476+
let fanIn = self[inputAxis] * spatialSize
477+
let outputAxis = endIndex - 1
478+
let fanOut = self[outputAxis] * spatialSize
479+
return (fanIn, fanOut)
468480
}
469481
}
470482

471483
public extension Tensor where Scalar: TensorFlowFloatingPoint {
472-
/// Creates a tensor by performing Glorot uniform initialization for the specified shape,
473-
/// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
474-
/// generated by the default random number generator, where limit is
484+
/// Creates a tensor with the specified shape by performing Glorot uniform initialization.
485+
///
486+
/// It draws random samples from a uniform distribution between `-limit` and `limit`
487+
/// generated by the default random number generator, where `limit` is
475488
/// `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of input and output
476-
/// features multiplied by the receptive field if present.
489+
/// features multiplied by the receptive field size.
490+
///
491+
/// Reference: ["Understanding the difficulty of training deep feedforward neural networks"](
492+
/// http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
477493
///
478494
/// - Parameters:
479495
/// - shape: The dimensions of the tensor.
480496
init(glorotUniform shape: TensorShape, seed: TensorFlowSeed = Context.local.randomSeed) {
481-
let uniform = Tensor(randomUniform: shape, seed: seed)
482-
self = Tensor.glorot(fromStandardUniform: uniform, shape: shape)
497+
let (fanIn, fanOut) = shape.fans()
498+
let limit = Tensor<Scalar>(6 / Scalar(fanIn + fanOut))
499+
self.init(randomUniform: shape, lowerBound: -limit, upperBound: limit, seed: seed)
483500
}
484-
}
485501

486-
// TODO: Can become fileprivate after the 0.4 release.
487-
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
488-
static func glorot(
489-
fromStandardNormal standardNormal: __shared Tensor<Scalar>,
490-
shape: __shared TensorShape
491-
) -> Tensor<Scalar> {
492-
let spatialDimCount = shape.count - 2
493-
let receptiveField = shape[0..<spatialDimCount].contiguousSize
494-
let fanIn = shape[shape.count - 2] * receptiveField
495-
let fanOut = shape[shape.count - 1] * receptiveField
496-
let minusOneToOne = 2 * standardNormal - 1
497-
return Scalar.sqrt(Scalar(2) / Scalar(fanIn + fanOut)) * minusOneToOne
498-
}
499-
}
500-
501-
public extension Tensor where Scalar: TensorFlowFloatingPoint {
502-
/// Creates a tensor by performing Glorot normal initialization for the specified shape,
503-
/// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
504-
/// generated by the default random number generator, where limit is
505-
/// `sqrt(2 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of input and output
506-
/// features multiplied by the receptive field if present.
502+
/// Creates a tensor with the specified shape by performing Glorot normal initialization.
503+
///
504+
/// It draws random samples from a truncated normal distribution centered on `0` with
505+
/// standard deviation `sqrt(2 / (fanIn + fanOut))` generated by the default random number
506+
/// generator, where `fanIn`/`fanOut` represent the number of input and output features
507+
/// multiplied by the receptive field size.
508+
///
509+
/// Reference: ["Understanding the difficulty of training deep feedforward neural networks"](
510+
/// http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
507511
///
508512
/// - Parameters:
509513
/// - shape: The dimensions of the tensor.
510514
init(glorotNormal shape: TensorShape, seed: TensorFlowSeed = Context.local.randomSeed) {
511-
let normal = Tensor(randomNormal: shape, seed: seed)
512-
self = Tensor.glorot(fromStandardNormal: normal, shape: shape)
515+
let (fanIn, fanOut) = shape.fans()
516+
var standardDeviation = Tensor<Scalar>(Scalar.sqrt(2 / Scalar(fanIn + fanOut)))
517+
// Standard deviation of truncated standard normal between `-2` and `2` standard deviations.
518+
let truncationDeviation = Tensor<Scalar>(0.87962566103423978)
519+
standardDeviation /= truncationDeviation // Smooths the tails of the clipped normal.
520+
self.init(
521+
randomTruncatedNormal: shape,
522+
mean: Tensor<Scalar>(0),
523+
standardDeviation: standardDeviation,
524+
seed: seed)
513525
}
514526
}
515527

Tests/TensorFlowTests/InitializerTests.swift

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,88 @@ final class InitializerTests: XCTestCase {
9494
XCTAssertEqual(ShapedArray(shape: [2, 2], scalars: [1, 0, 1, 0]), i8s.array)
9595
}
9696

97+
// Constants for testing distribution based initializers.
98+
private let fcShape = TensorShape([200, 100])
99+
private let convShape = TensorShape([25, 25, 20, 20])
100+
private let tolerance = Float(3e-2)
101+
102+
func testDistribution(
103+
_ t: Tensor<Float>,
104+
expectedMean: Float? = nil,
105+
expectedStandardDeviation: Float? = nil,
106+
expectedMin: Float? = nil,
107+
expectedMax: Float? = nil
108+
) {
109+
if let expectedMean = expectedMean {
110+
let mean = t.mean().scalarized()
111+
XCTAssertTrue(abs(mean - expectedMean) < tolerance)
112+
}
113+
if let expectedStandardDeviation = expectedStandardDeviation {
114+
let standardDeviation = t.standardDeviation().scalarized()
115+
XCTAssertTrue(abs(standardDeviation - expectedStandardDeviation) < tolerance)
116+
}
117+
if let expectedMin = expectedMin {
118+
let min = t.min().scalarized()
119+
XCTAssertTrue(abs(min - expectedMin) < tolerance)
120+
}
121+
if let expectedMax = expectedMax {
122+
let max = t.max().scalarized()
123+
XCTAssertTrue(abs(max - expectedMax) < tolerance)
124+
}
125+
}
126+
127+
func testRandomUniform() {
128+
do {
129+
let t = Tensor<Float>(
130+
randomUniform: fcShape,
131+
lowerBound: Tensor(2),
132+
upperBound: Tensor(3))
133+
testDistribution(t, expectedMean: 2.5, expectedMin: 2, expectedMax: 3)
134+
}
135+
do {
136+
let t = Tensor<Float>(
137+
randomUniform: fcShape,
138+
lowerBound: Tensor(-1),
139+
upperBound: Tensor(1))
140+
testDistribution(t, expectedMean: 0, expectedMin: -1, expectedMax: 1)
141+
}
142+
}
143+
144+
func testRandomNormal() {
145+
let t = Tensor<Float>(
146+
randomNormal: convShape,
147+
mean: Tensor(1),
148+
standardDeviation: Tensor(2))
149+
testDistribution(t, expectedMean: 1, expectedStandardDeviation: 2)
150+
}
151+
152+
func testRandomTruncatedNormal() {
153+
let t = Tensor<Float>(randomTruncatedNormal: convShape)
154+
testDistribution(t, expectedMean: 0, expectedMin: -2, expectedMax: 2)
155+
}
156+
157+
func testGlorotUniform() {
158+
let t = Tensor<Float>(glorotUniform: convShape)
159+
let spatialSize = convShape[0..<2].contiguousSize
160+
let (fanIn, fanOut) = (convShape[2] * spatialSize, convShape[3] * spatialSize)
161+
let stdDev = sqrt(Float(2.0) / Float(fanIn + fanOut))
162+
testDistribution(t, expectedMean: 0, expectedStandardDeviation: stdDev)
163+
}
164+
165+
func testGlorotNormal() {
166+
let t = Tensor<Float>(glorotNormal: convShape)
167+
let spatialSize = convShape[0..<2].contiguousSize
168+
let (fanIn, fanOut) = (convShape[2] * spatialSize, convShape[3] * spatialSize)
169+
let stdDev = sqrt(Float(2.0) / Float(fanIn + fanOut))
170+
testDistribution(t, expectedMean: 0, expectedStandardDeviation: stdDev)
171+
}
172+
97173
func testOrthogonalShapesValues() {
98174
for shape in [[10, 10], [10, 9, 8], [100, 5, 5], [50, 40], [3, 3, 32, 64]] {
99175
// Check the shape.
100176
var t = Tensor<Float>(orthogonal: TensorShape(shape))
101177
XCTAssertEqual(shape, t.shape.dimensions)
102-
178+
103179
// Check orthogonality by computing the inner product.
104180
t = t.reshaped(to: [t.shape.dimensions.dropLast().reduce(1, *), t.shape[t.rank - 1]])
105181
if t.shape[0] > t.shape[1] {
@@ -120,6 +196,11 @@ final class InitializerTests: XCTestCase {
120196
("testArrayConversion", testArrayConversion),
121197
("testDataTypeCast", testDataTypeCast),
122198
("testBoolToNumericCast", testBoolToNumericCast),
199+
("testRandomUniform", testRandomUniform),
200+
("testRandomNormal", testRandomNormal),
201+
("testRandomTruncatedNormal", testRandomTruncatedNormal),
202+
("testGlorotUniform", testGlorotUniform),
203+
("testGlorotNormal", testGlorotNormal),
123204
("testOrthogonalShapesValues", testOrthogonalShapesValues)
124205
]
125206
}

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,11 +1120,13 @@ final class LayerTests: XCTestCase {
11201120
let (outputs, _) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
11211121
return rnn(inputs)
11221122
}
1123-
XCTAssertEqual(outputs.map { $0.value },
1124-
[[[ 0.20775771, 0.20080023, -0.13768704, -0.18534681]],
1125-
[[ 0.22666009, 0.30019346, -0.19720285, -0.14683801]],
1126-
[[ 0.23758979, 0.32101023, -0.20359215, -0.1787096]],
1127-
[[ 0.24337786, 0.3389194, -0.21143384, -0.1675081]]])
1123+
assertEqual(
1124+
outputs.map { $0.value.squeezingShape(at: 0) }[0],
1125+
[[ 0.14798240, 0.14295710, -0.09766942, -0.131820890],
1126+
[ 0.15757358, 0.19475500, -0.12810913, -0.112212844],
1127+
[ 0.16168950, 0.20306197, -0.13058113, -0.123917180],
1128+
[ 0.16325668, 0.20822097, -0.13273866, -0.121018395]],
1129+
accuracy: 1e-6)
11281130
}
11291131
// TODO: Figure out why the following is numerically unstable.
11301132
// let (𝛁rnn, _) = pullback(.init(inputs.map { SimpleRNNCell<Float>.State($0) }))
@@ -1149,18 +1151,20 @@ final class LayerTests: XCTestCase {
11491151
let (outputs, _) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
11501152
return rnn(inputs)
11511153
}
1152-
XCTAssertEqual(
1153-
outputs.map { $0.cell },
1154-
[[[ 0.08981595, 0.027691621, -0.059235442, -0.075101905]],
1155-
[[ 0.12952757, 0.040402323, -0.084273980, -0.116252676]],
1156-
[[ 0.14727503, 0.046511370, -0.094689950, -0.138459030]],
1157-
[[ 0.15532997, 0.049573865, -0.098824400, -0.150242210]]])
1158-
XCTAssertEqual(
1159-
outputs.map { $0.hidden },
1160-
[[[ 0.046985064, 0.012670102, -0.031083463, -0.038572006]],
1161-
[[ 0.066482050, 0.018388016, -0.044252350, -0.058907583]],
1162-
[[ 0.074910110, 0.021107012, -0.049724963, -0.069670826]],
1163-
[[ 0.078670055, 0.022462710, -0.051899005, -0.075331904]]])
1154+
assertEqual(
1155+
outputs.map { $0.cell.squeezingShape(at: 0) }[0],
1156+
[[ 0.047114454, 0.013959665, -0.030737250, -0.038524970],
1157+
[ 0.069171116, 0.020617897, -0.044740470, -0.058878290],
1158+
[ 0.079530790, 0.023841830, -0.051080680, -0.069567055],
1159+
[ 0.084416830, 0.025424266, -0.053918116, -0.075140170]],
1160+
accuracy: 1e-6)
1161+
assertEqual(
1162+
outputs.map { $0.hidden.squeezingShape(at: 0) }[0],
1163+
[[ 0.024117637, 0.0066833394, -0.015753632, -0.019533360],
1164+
[ 0.035230752, 0.0098582430, -0.022934474, -0.029750597],
1165+
[ 0.040405065, 0.0113919870, -0.026185552, -0.035087958],
1166+
[ 0.042834233, 0.0121438510, -0.027640648, -0.037863784]],
1167+
accuracy: 1e-6)
11641168
}
11651169
}
11661170
}

Tests/TensorFlowTests/SequentialTests.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ final class SequentialTests: XCTestCase {
6060
adadelta.update(&model, along: 𝛁model)
6161
}
6262
}
63-
XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
64-
[[0.4884567], [0.4884567], [0.4884567], [0.4884567]])
63+
assertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
64+
[[0.5115531], [0.5115531], [0.5115531], [0.5115531]],
65+
accuracy: 1e-6)
6566
}
6667

6768
static var allTests = [

0 commit comments

Comments (0)