Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 47de233

Browse files
authored
Correcting calculation of Glorot uniform for Tensors (#576)
* Correcting calculation of Glorot uniform for Tensors. * Restored the GRU test values to what they were before they were adjusted for the new uniform.
1 parent da8c8d6 commit 47de233

File tree

4 files changed

+24
-22
lines changed

4 files changed

+24
-22
lines changed

Sources/TensorFlow/Initializers.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
495495
/// - shape: The dimensions of the tensor.
496496
init(glorotUniform shape: TensorShape, seed: TensorFlowSeed = Context.local.randomSeed) {
497497
let (fanIn, fanOut) = shape.fans()
498-
let limit = Tensor<Scalar>(6 / Scalar(fanIn + fanOut))
498+
let limit = Tensor<Scalar>(Scalar.sqrt(6 / Scalar(fanIn + fanOut)))
499499
self.init(randomUniform: shape, lowerBound: -limit, upperBound: limit, seed: seed)
500500
}
501501

Tests/TensorFlowTests/InitializerTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,14 @@ final class InitializerTests: XCTestCase {
9797
// Constants for testing distribution based initializers.
9898
private let fcShape = TensorShape([200, 100])
9999
private let convShape = TensorShape([25, 25, 20, 20])
100-
private let tolerance = Float(3e-2)
101100

102101
func testDistribution(
103102
_ t: Tensor<Float>,
104103
expectedMean: Float? = nil,
105104
expectedStandardDeviation: Float? = nil,
106105
expectedMin: Float? = nil,
107-
expectedMax: Float? = nil
106+
expectedMax: Float? = nil,
107+
tolerance: Float = 3e-2
108108
) {
109109
if let expectedMean = expectedMean {
110110
let mean = t.mean().scalarized()
@@ -159,7 +159,7 @@ final class InitializerTests: XCTestCase {
159159
let spatialSize = convShape[0..<2].contiguousSize
160160
let (fanIn, fanOut) = (convShape[2] * spatialSize, convShape[3] * spatialSize)
161161
let stdDev = sqrt(Float(2.0) / Float(fanIn + fanOut))
162-
testDistribution(t, expectedMean: 0, expectedStandardDeviation: stdDev)
162+
testDistribution(t, expectedMean: 0, expectedStandardDeviation: stdDev, tolerance: 1e-4)
163163
}
164164

165165
func testGlorotNormal() {

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,10 +1147,10 @@ final class LayerTests: XCTestCase {
11471147
}
11481148
assertEqual(
11491149
outputs.map { $0.value.squeezingShape(at: 0) }[0],
1150-
[[ 0.14798240, 0.14295710, -0.09766942, -0.131820890],
1151-
[ 0.15757358, 0.19475500, -0.12810913, -0.112212844],
1152-
[ 0.16168950, 0.20306197, -0.13058113, -0.123917180],
1153-
[ 0.16325668, 0.20822097, -0.13273866, -0.121018395]],
1150+
[[ 0.20775771, 0.20080023, -0.13768704, -0.18534681],
1151+
[ 0.22666009, 0.30019346, -0.19720285, -0.14683801],
1152+
[ 0.23758979, 0.32101023, -0.20359215, -0.1787096],
1153+
[ 0.24337786, 0.3389194, -0.21143384, -0.1675081]],
11541154
accuracy: 1e-6)
11551155
}
11561156
// TODO: Figure out why the following is numerically unstable.
@@ -1178,17 +1178,17 @@ final class LayerTests: XCTestCase {
11781178
}
11791179
assertEqual(
11801180
outputs.map { $0.cell.squeezingShape(at: 0) }[0],
1181-
[[ 0.047114454, 0.013959665, -0.030737250, -0.038524970],
1182-
[ 0.069171116, 0.020617897, -0.044740470, -0.058878290],
1183-
[ 0.079530790, 0.023841830, -0.051080680, -0.069567055],
1184-
[ 0.084416830, 0.025424266, -0.053918116, -0.075140170]],
1181+
[[ 0.08981595, 0.027691621, -0.059235442, -0.075101905],
1182+
[ 0.12952757, 0.040402323, -0.084273980, -0.116252676],
1183+
[ 0.14727503, 0.046511370, -0.094689950, -0.138459030],
1184+
[ 0.15532997, 0.049573865, -0.098824400, -0.150242210]],
11851185
accuracy: 1e-6)
11861186
assertEqual(
11871187
outputs.map { $0.hidden.squeezingShape(at: 0) }[0],
1188-
[[ 0.024117637, 0.0066833394, -0.015753632, -0.019533360],
1189-
[ 0.035230752, 0.0098582430, -0.022934474, -0.029750597],
1190-
[ 0.040405065, 0.0113919870, -0.026185552, -0.035087958],
1191-
[ 0.042834233, 0.0121438510, -0.027640648, -0.037863784]],
1188+
[[ 0.046985064, 0.012670102, -0.031083463, -0.038572006],
1189+
[ 0.066482050, 0.018388016, -0.044252350, -0.058907583],
1190+
[ 0.074910110, 0.021107012, -0.049724963, -0.069670826],
1191+
[ 0.078670055, 0.022462710, -0.051899005, -0.075331904]],
11921192
accuracy: 1e-6)
11931193
}
11941194
}
@@ -1207,11 +1207,13 @@ final class LayerTests: XCTestCase {
12071207
let (outputs, _) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
12081208
return rnn(inputs)
12091209
}
1210-
XCTAssertEqual(outputs.map { $0.hidden },
1211-
[[[0.12806869, 0.12806869, 0.12806869, 0.12806869]],
1212-
[[0.2007559, 0.2007559, 0.2007559, 0.2007559]],
1213-
[[0.23432666, 0.23432666, 0.23432666, 0.23432666]],
1214-
[[0.24788898, 0.24788898, 0.24788898, 0.24788898]]])
1210+
assertEqual(
1211+
outputs.map { $0.hidden }[0],
1212+
[[0.1193780, 0.1193780, 0.1193780, 0.1193780],
1213+
[0.1887644, 0.1887644, 0.1887644, 0.1887644],
1214+
[0.2230835, 0.2230835, 0.2230835, 0.2230835],
1215+
[0.2383619, 0.2383619, 0.2383619, 0.2383619]],
1216+
accuracy: 1e-5)
12151217
}
12161218
}
12171219

Tests/TensorFlowTests/SequentialTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ final class SequentialTests: XCTestCase {
6363
}
6464
}
6565
XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
66-
[[0.5567076], [0.5567076], [0.5567076], [0.5567076]])
66+
[[0.50378805], [0.50378805], [0.50378805], [0.50378805]])
6767
}
6868

6969
static var allTests = [

0 commit comments

Comments
 (0)