@@ -410,6 +410,90 @@ final class LayerTests: XCTestCase {
         let expected = Tensor<Float>([[0.0], [0.7615942], [0.9640276], [0.9950547], [0.9993292]])
         XCTAssertEqual(output, expected)
     }
+
+    func testBatchNorm() {
+        let x = Tensor<Float>([
+            [-1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],
+            [1.5165012, 0.3753972, -0.30856386, -0.3100725, -1.9584457],
+            [0.006384419, 1.4424847, 0.91568077, 0.66328526, -1.0794537],
+            [1.056803, 0.14263044, -1.8308276, 0.4189805, 0.6933893],
+            [0.30175626, -0.16121633, -0.4191958, -0.53092813, -0.029484272]])
+        let bnLayer = BatchNorm<Float>(featureCount: 5, axis: 0)
+        Context.local.learningPhase = .training
+        let trainingValue = bnLayer(x)
+        let grad = gradient(at: x, bnLayer) { $1($0).squared().sum() }
+        // The expected values and gradients were computed using the following Python code:
+        // ```
+        // x = tf.constant(
+        //     [[-1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],
+        //      [1.5165012, 0.3753972, -0.30856386, -0.3100725, -1.9584457],
+        //      [0.006384419, 1.4424847, 0.91568077, 0.66328526, -1.0794537],
+        //      [1.056803, 0.14263044, -1.8308276, 0.4189805, 0.6933893],
+        //      [0.30175626, -0.16121633, -0.4191958, -0.53092813, -0.029484272]])
+        // scale = tf.reshape(tf.constant([1., 1., 1., 1., 1.]), [5, 1])
+        // offset = tf.reshape(tf.constant([0., 0., 0., 0., 0.]), [5, 1])
+        // (mean, var) = tf.nn.moments(x, axes=1, keepdims=True)
+        // bn = tf.nn.batch_normalization(x, mean, var, offset=offset, scale=scale, variance_epsilon=0.001)
+        // scaled = tf.reduce_sum(tf.square(bn))
+        // g = tf.gradients(scaled, [x, offset, scale])
+        // init = tf.initialize_all_variables()
+        // with tf.Session() as sess:
+        //     sess.run(init)
+        //     print(sess.run([bn, g]))
+        // ```
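+        // For reference, tf.nn.batch_normalization computes
+        //     bn = scale * (x - mean) / sqrt(var + variance_epsilon) + offset
+        // so with scale = 1 and offset = 0, each feature row of `x` is standardized using
+        // its own mean and (population) variance, taken along axis 1.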
+        let expectedTrainingValue = Tensor<Float>([
+            [-1.5439795, -0.16477099, -0.11604305, 0.24174842, 1.5830451],
+            [1.4639764, 0.45368853, -0.15186328, -0.15319899, -1.6126028],
+            [-0.44139984, 1.2124169, 0.60574806, 0.3150888, -1.6918538],
+            [0.9507547, 0.04595902, -1.9072568, 0.31947452, 0.5910686],
+            [1.5834246, 0.02224666, -0.8476793, -1.2244489, 0.46645695]])
+
+        let expectedInputGradient = Tensor<Float>([
+            [-1.0127544e-02, -1.0807812e-03, -7.6115131e-04, 1.5857220e-03, 1.0383606e-02],
+            [2.0323221e-03, 6.2976527e-04, -2.1077941e-04, -2.1265696e-04, -2.2384699e-03],
+            [-1.3483668e-03, 3.7030075e-03, 1.8500184e-03, 9.6232636e-04, -5.1673558e-03],
+            [1.8438101e-03, 8.9146197e-05, -3.6990643e-03, 6.1964989e-04, 1.1463165e-03],
+            [1.2142579e-01, 1.7060755e-03, -6.5005139e-02, -9.3897656e-02, 3.5770576e-02]])
+        let expectedScaleGradient = Tensor<Float>([9.977925, 9.992161, 9.986738, 9.990202, 9.886292])
+        let expectedOffsetGradient = Tensor<Float>([0.0, 0.0, 0.0, 0.0, 0.0])
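+        // Sanity check on the parameter gradients: for sum(bn^2), d/d(scale) is
+        // 2 * sum(bn^2) = 2 * 5 * var / (var + eps) per feature (hence just under 10), and
+        // d/d(offset) is 2 * sum(bn) per feature, which is exactly zero because each
+        // normalized row has zero mean.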
+        assertEqual(expectedTrainingValue, trainingValue, accuracy: 1e-5)
+        assertEqual(expectedInputGradient, grad.0, accuracy: 1e-5)
+        assertEqual(expectedScaleGradient, grad.1.scale, accuracy: 1e-5)
+        assertEqual(expectedOffsetGradient, grad.1.offset, accuracy: 1e-5)
+    }
+
+    func testLayerNorm() {
+        let x = Tensor<Float>([
+            [-1.0474433, -0.11914538, -0.08634827, 0.15446888, 1.0572497],
+            [1.5165012, 0.3753972, -0.30856386, -0.3100725, -1.9584457],
+            [0.006384419, 1.4424847, 0.91568077, 0.66328526, -1.0794537],
+            [1.056803, 0.14263044, -1.8308276, 0.4189805, 0.6933893],
+            [0.30175626, -0.16121633, -0.4191958, -0.53092813, -0.029484272]])
+        let lnLayer = LayerNorm<Float>(featureCount: 5, axis: 1)
+        let value = lnLayer(x)
+        let grad = gradient(at: x, lnLayer) { $1($0).squared().sum() }
+        // Uses the same expected values as testBatchNorm() above: LayerNorm with features on
+        // axis 1 is equivalent to BatchNorm with features on axis 0, since both normalize
+        // each row of `x` with that row's mean and variance.
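+        // For example, the following hypothetical check (not part of this test) should pass,
+        // assuming both layers use the same default epsilon, as the matching expected values
+        // below suggest:
+        //     Context.local.learningPhase = .training
+        //     let bnLayer = BatchNorm<Float>(featureCount: 5, axis: 0)
+        //     assertEqual(bnLayer(x), lnLayer(x), accuracy: 1e-5)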
+        let expectedValue = Tensor<Float>([
+            [-1.5439795, -0.16477099, -0.11604305, 0.24174842, 1.5830451],
+            [1.4639764, 0.45368853, -0.15186328, -0.15319899, -1.6126028],
+            [-0.44139984, 1.2124169, 0.60574806, 0.3150888, -1.6918538],
+            [0.9507547, 0.04595902, -1.9072568, 0.31947452, 0.5910686],
+            [1.5834246, 0.02224666, -0.8476793, -1.2244489, 0.46645695]])
+
+        let expectedInputGradient = Tensor<Float>([
+            [-1.0127544e-02, -1.0807812e-03, -7.6115131e-04, 1.5857220e-03, 1.0383606e-02],
+            [2.0323221e-03, 6.2976527e-04, -2.1077941e-04, -2.1265696e-04, -2.2384699e-03],
+            [-1.3483668e-03, 3.7030075e-03, 1.8500184e-03, 9.6232636e-04, -5.1673558e-03],
+            [1.8438101e-03, 8.9146197e-05, -3.6990643e-03, 6.1964989e-04, 1.1463165e-03],
+            [1.2142579e-01, 1.7060755e-03, -6.5005139e-02, -9.3897656e-02, 3.5770576e-02]])
+        let expectedScaleGradient = Tensor<Float>([9.977925, 9.992161, 9.986738, 9.990202, 9.886292])
+        let expectedOffsetGradient = Tensor<Float>([0.0, 0.0, 0.0, 0.0, 0.0])
+        assertEqual(expectedValue, value, accuracy: 1e-5)
+        assertEqual(expectedInputGradient, grad.0, accuracy: 1e-5)
+        assertEqual(expectedScaleGradient, grad.1.scale, accuracy: 1e-5)
+        assertEqual(expectedOffsetGradient, grad.1.offset, accuracy: 1e-5)
+    }
 
     static var allTests = [
         ("testSequential", testSequential),
@@ -443,6 +527,8 @@ final class LayerTests: XCTestCase {
         ("testSimpleRNNCell", testSimpleRNNCell),
         ("testDense", testDense),
         ("testRNN", testRNN),
-        ("testFunction", testFunction)
+        ("testFunction", testFunction),
+        ("testBatchNorm", testBatchNorm),
+        ("testLayerNorm", testLayerNorm)
     ]
 }