This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c09de96

jon-tow authored and dan-zheng committed

[linalg] Add support for cholesky decomposition (#563)

1 parent f24d1a2 · commit c09de96

2 files changed: +58 -0 lines changed


Sources/TensorFlow/Operators/Math.swift

Lines changed: 27 additions & 0 deletions
@@ -2686,6 +2686,33 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
     }
 }
 
+/// Returns the Cholesky decomposition of one or more square matrices.
+///
+/// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+/// form square matrices.
+///
+/// The input has to be symmetric and positive definite. Only the lower-triangular
+/// part of the input will be used for this operation. The upper-triangular part
+/// will not be read.
+///
+/// The output is a tensor of the same shape as the input
+/// containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
+///
+/// - Parameter input: A tensor of shape `[..., M, M]`.
+@inlinable
+@differentiable(vjp: _vjpCholesky)
+public func cholesky<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+    _Raw.cholesky(x)
+}
+
+@inlinable
+internal func _vjpCholesky<T: TensorFlowFloatingPoint>(
+    _ x: Tensor<T>
+) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
+    let decomposition = cholesky(x)
+    return (decomposition, { v in _Raw.choleskyGrad(l: decomposition, grad: v) })
+}
+
 public extension Tensor where Scalar: TensorFlowFloatingPoint {
     /// Returns the QR decomposition of each inner matrix in the tensor, a tensor with inner
     /// orthogonal matrices `q` and a tensor with inner upper triangular matrices `r`, such that the
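
For orientation, here is a minimal usage sketch of the API added above. It is not part of the commit; it assumes only the new `cholesky` function plus standard `TensorFlow` module symbols that also appear in this diff (`matmul`, `transposed(permutation:)`, `gradient(at:)`):

```swift
import TensorFlow

// Build a symmetric positive-definite matrix X = A·Aᵀ from a random square A.
let a = Tensor<Float>(randomNormal: [3, 3])
let x = matmul(a, a.transposed(permutation: [1, 0]))

// Lower-triangular factor L with L·Lᵀ ≈ X.
let l = cholesky(x)

// Because `cholesky` is marked `@differentiable`, gradients flow through
// `_vjpCholesky`, which delegates to the raw `choleskyGrad` op.
let dx = gradient(at: x) { cholesky($0).sum() }
print(l.shape, dx.shape)  // [3, 3] [3, 3]
```

Note that the VJP reuses the forward result (`decomposition`) and passes it to the raw `choleskyGrad` op, so the backward pass does not re-factorize the input.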

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 31 additions & 0 deletions
@@ -484,6 +484,36 @@ final class MathOperatorTests: XCTestCase {
         XCTAssertEqual(Double(prediction.scalars[0]), 0.816997, accuracy: 0.0001)
     }
 
+    func testCholesky() {
+        let shapes = [[3, 3], [4, 2, 2], [2, 1, 16, 16]]
+        let permutations = [[1, 0], [0, 2, 1], [0, 1, 3, 2]] // To avoid permuting batch dimensions.
+        for (shape, permutation) in zip(shapes, permutations) {
+            let a = Tensor<Float>(randomNormal: TensorShape(shape))
+            let x = matmul(a, a.transposed(permutation: permutation)) // Make `a` positive-definite.
+            let l = cholesky(x)
+            let xReconstructed = matmul(l, l.transposed(permutation: permutation))
+            assertEqual(xReconstructed, x, accuracy: 1e-5)
+        }
+
+        // The expected value of the gradient was computed using the following Python code:
+        // ```
+        // import tensorflow as tf
+        // x = tf.constant([[[6., 4.], [4., 6.]], [[2., 6.], [6., 20.]]])
+        // with tf.GradientTape() as tape:
+        //     tape.watch(x)
+        //     l = tf.reduce_sum(tf.linalg.cholesky(x))
+        // print(tape.gradient(l, x))
+        // ```
+        let x = Tensor<Float>([[[6, 4], [4, 6]], [[2, 6], [6, 20]]])
+        let computedGradient = gradient(at: x) { cholesky($0).sum() }
+        let expectedGradient = Tensor<Float>([
+            [[0.1897575, 0.02154995],
+             [0.02154995, 0.2738613]],
+            [[2.4748755, -0.7071073],
+             [-0.7071073, 0.3535535]]])
+        assertEqual(computedGradient, expectedGradient, accuracy: 1e-5)
+    }
+
     func testQRDecompositionApproximation() {
         let shapes = [[5, 8], [3, 4, 4], [3, 3, 32, 64]]
         for shape in shapes {

@@ -570,6 +600,7 @@ final class MathOperatorTests: XCTestCase {
         ("testXWPlusB", testXWPlusB),
         ("testXORInference", testXORInference),
         ("testMLPClassifierStruct", testMLPClassifierStruct),
+        ("testCholesky", testCholesky),
         ("testQRDecompositionApproximation", testQRDecompositionApproximation),
         ("testDiagonalPart", testDiagonalPart),
         ("testBroadcastedAddGradient", testBroadcastedAddGradient)
