@@ -484,6 +484,36 @@ final class MathOperatorTests: XCTestCase {
484
484
XCTAssertEqual ( Double ( prediction. scalars [ 0 ] ) , 0.816997 , accuracy: 0.0001 )
485
485
}
486
486
487
+ func testCholesky( ) {
488
+ let shapes = [ [ 3 , 3 ] , [ 4 , 2 , 2 ] , [ 2 , 1 , 16 , 16 ] ]
489
+ let permutations = [ [ 1 , 0 ] , [ 0 , 2 , 1 ] , [ 0 , 1 , 3 , 2 ] ] // To avoid permuting batch dimensions.
490
+ for (shape, permutation) in zip ( shapes, permutations) {
491
+ let a = Tensor < Float > ( randomNormal: TensorShape ( shape) )
492
+ let x = matmul ( a, a. transposed ( permutation: permutation) ) // Make `a` positive-definite.
493
+ let l = cholesky ( x)
494
+ let xReconstructed = matmul ( l, l. transposed ( permutation: permutation) )
495
+ assertEqual ( xReconstructed, x, accuracy: 1e-5 )
496
+ }
497
+
498
+ // The expected value of the gradient was computed using the following Python code:
499
+ // ```
500
+ // import tensorflow as tf
501
+ // x = tf.constant([[[6., 4.], [4., 6.]], [[2., 6.], [6., 20.]]])
502
+ // with tf.GradientTape() as tape:
503
+ // tape.watch(x)
504
+ // l = tf.reduce_sum(tf.linalg.cholesky(x))
505
+ // print(tape.gradient(l, x))
506
+ // ```
507
+ let x = Tensor < Float > ( [ [ [ 6 , 4 ] , [ 4 , 6 ] ] , [ [ 2 , 6 ] , [ 6 , 20 ] ] ] )
508
+ let computedGradient = gradient ( at: x) { cholesky ( $0) . sum ( ) }
509
+ let expectedGradient = Tensor < Float > ( [
510
+ [ [ 0.1897575 , 0.02154995 ] ,
511
+ [ 0.02154995 , 0.2738613 ] ] ,
512
+ [ [ 2.4748755 , - 0.7071073 ] ,
513
+ [ - 0.7071073 , 0.3535535 ] ] ] )
514
+ assertEqual ( computedGradient, expectedGradient, accuracy: 1e-5 )
515
+ }
516
+
487
517
func testQRDecompositionApproximation( ) {
488
518
let shapes = [ [ 5 , 8 ] , [ 3 , 4 , 4 ] , [ 3 , 3 , 32 , 64 ] ]
489
519
for shape in shapes {
@@ -570,6 +600,7 @@ final class MathOperatorTests: XCTestCase {
570
600
( " testXWPlusB " , testXWPlusB) ,
571
601
( " testXORInference " , testXORInference) ,
572
602
( " testMLPClassifierStruct " , testMLPClassifierStruct) ,
603
+ ( " testCholesky " , testCholesky) ,
573
604
( " testQRDecompositionApproximation " , testQRDecompositionApproximation) ,
574
605
( " testDiagonalPart " , testDiagonalPart) ,
575
606
( " testBroadcastedAddGradient " , testBroadcastedAddGradient)
0 commit comments