@@ -453,63 +453,75 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
453
453
}
454
454
}
455
455
456
- // TODO: Can become fileprivate after the 0.4 release.
457
- internal extension Tensor where Scalar: TensorFlowFloatingPoint {
458
- static func glorot(
459
- fromStandardUniform randomUniform: __shared Tensor< Scalar > ,
460
- shape: __shared TensorShape
461
- ) -> Tensor < Scalar > {
462
- let spatialDimCount = shape. count - 2
463
- let receptiveField = shape [ 0 ..< spatialDimCount] . contiguousSize
464
- let fanIn = shape [ shape. count - 2 ] * receptiveField
465
- let fanOut = shape [ shape. count - 1 ] * receptiveField
466
- let minusOneToOne = 2 * randomUniform - 1
467
- return Scalar . sqrt ( Scalar ( 6 ) / Scalar( fanIn + fanOut) ) * minusOneToOne
456
+ //===------------------------------------------------------------------------------------------===//
457
+ // Variance Scaling
458
+ //===------------------------------------------------------------------------------------------===//
459
+
460
+ fileprivate extension TensorShape {
461
+ // Returns the `fanIn` and `fanOut` counts for `TensorShape`s where the last two axes represent
462
+ // the input channel count and output channel count, respectively.
463
+ func fans( ) -> ( in: Int , out: Int ) {
464
+ precondition (
465
+ count > 1 ,
466
+ " Fans cannot be computed for tensors with fewer than 2 dimensions. Got: \( count) " )
467
+
468
+ // Fans for a 2-D tensor, e.g. `Dense`/`Embedding` weights.
469
+ if count == 2 {
470
+ return ( self [ 0 ] , self [ 1 ] )
471
+ }
472
+ // Fans for tensors with rank greater than `2`, specifically convolution filters.
473
+ let lastSpatialAxis = endIndex - 3
474
+ let spatialSize = self [ 0 ..< ( lastSpatialAxis + 1 ) ] . contiguousSize
475
+ let inputAxis = endIndex - 2
476
+ let fanIn = self [ inputAxis] * spatialSize
477
+ let outputAxis = endIndex - 1
478
+ let fanOut = self [ outputAxis] * spatialSize
479
+ return ( fanIn, fanOut)
468
480
}
469
481
}
470
482
471
483
public extension Tensor where Scalar: TensorFlowFloatingPoint {
472
- /// Creates a tensor by performing Glorot uniform initialization for the specified shape,
473
- /// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
474
- /// generated by the default random number generator, where limit is
484
+ /// Creates a tensor with the specified shape by performing Glorot uniform initialization.
485
+ ///
486
+ /// It draws random samples from a uniform distribution between `-limit` and `limit`
487
+ /// generated by the default random number generator, where `limit` is
475
488
/// `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of input and output
476
- /// features multiplied by the receptive field if present.
489
+ /// features multiplied by the receptive field size.
490
+ ///
491
+ /// Reference: ["Understanding the difficulty of training deep feedforward neural networks"](
492
+ /// http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
477
493
///
478
494
/// - Parameters:
479
495
/// - shape: The dimensions of the tensor.
480
496
init ( glorotUniform shape: TensorShape , seed: TensorFlowSeed = Context . local. randomSeed) {
481
- let uniform = Tensor ( randomUniform: shape, seed: seed)
482
- self = Tensor . glorot ( fromStandardUniform: uniform, shape: shape)
497
+ let ( fanIn, fanOut) = shape. fans ( )
498
+ let limit = Tensor < Scalar > ( 6 / Scalar( fanIn + fanOut) )
499
+ self . init ( randomUniform: shape, lowerBound: - limit, upperBound: limit, seed: seed)
483
500
}
484
- }
485
501
486
- // TODO: Can become fileprivate after the 0.4 release.
487
- internal extension Tensor where Scalar: TensorFlowFloatingPoint {
488
- static func glorot(
489
- fromStandardNormal standardNormal: __shared Tensor< Scalar > ,
490
- shape: __shared TensorShape
491
- ) -> Tensor < Scalar > {
492
- let spatialDimCount = shape. count - 2
493
- let receptiveField = shape [ 0 ..< spatialDimCount] . contiguousSize
494
- let fanIn = shape [ shape. count - 2 ] * receptiveField
495
- let fanOut = shape [ shape. count - 1 ] * receptiveField
496
- let minusOneToOne = 2 * standardNormal - 1
497
- return Scalar . sqrt ( Scalar ( 2 ) / Scalar( fanIn + fanOut) ) * minusOneToOne
498
- }
499
- }
500
-
501
- public extension Tensor where Scalar: TensorFlowFloatingPoint {
502
- /// Creates a tensor by performing Glorot normal initialization for the specified shape,
503
- /// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
504
- /// generated by the default random number generator, where limit is
505
- /// `sqrt(2 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of input and output
506
- /// features multiplied by the receptive field if present.
502
+ /// Creates a tensor with the specified shape by performing Glorot normal initialization.
503
+ ///
504
+ /// It draws random samples from a truncated normal distribution centered on `0` with
505
+ /// standard deviation `sqrt(2 / (fanIn + fanOut))`generated by the default random number
506
+ /// generator, where `fanIn`/`fanOut` represent the number of input and output features
507
+ /// multiplied by the receptive field size.
508
+ ///
509
+ /// Reference: ["Understanding the difficulty of training deep feedforward neural networks"](
510
+ /// http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
507
511
///
508
512
/// - Parameters:
509
513
/// - shape: The dimensions of the tensor.
510
514
init ( glorotNormal shape: TensorShape , seed: TensorFlowSeed = Context . local. randomSeed) {
511
- let normal = Tensor ( randomNormal: shape, seed: seed)
512
- self = Tensor . glorot ( fromStandardNormal: normal, shape: shape)
515
+ let ( fanIn, fanOut) = shape. fans ( )
516
+ var standardDeviation = Tensor < Scalar > ( Scalar . sqrt ( 2 / Scalar( fanIn + fanOut) ) )
517
+ // Standard deviation of truncated standard normal between `-2` and `2` standard deviations.
518
+ let truncationDeviation = Tensor < Scalar > ( 0.87962566103423978 )
519
+ standardDeviation /= truncationDeviation // Smooths the tails of the clipped normal.
520
+ self . init (
521
+ randomTruncatedNormal: shape,
522
+ mean: Tensor < Scalar > ( 0 ) ,
523
+ standardDeviation: standardDeviation,
524
+ seed: seed)
513
525
}
514
526
}
515
527
0 commit comments