Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

[linalg] Add support for cholesky decomposition #563

Merged
merged 1 commit into from
Nov 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions Sources/TensorFlow/Operators/Math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2686,6 +2686,33 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
}
}

/// Returns the Cholesky decomposition of one or more square matrices.
///
/// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
/// form square matrices.
///
/// The input has to be symmetric and positive definite. Only the lower-triangular
/// part of the input will be used for this operation. The upper-triangular part
/// will not be read.
///
/// The output is a tensor of the same shape as the input
/// containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
///
/// - Parameter x: A tensor of shape `[..., M, M]`.
@inlinable
@differentiable(vjp: _vjpCholesky)
public func cholesky<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work with batching?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, it sure does! Check out some of the shapes used in testing.

let shapes = [[3, 3], [4, 2, 2], [2, 1, 16, 16]]

// Forwards directly to the raw op; the derivative is supplied separately by `_vjpCholesky`
// through the `@differentiable(vjp:)` attribute above.
_Raw.cholesky(x)
}

/// Computes `cholesky(x)` together with the pullback used for reverse-mode differentiation.
///
/// - Parameter x: The tensor handed to `cholesky`.
/// - Returns: A pair holding the decomposition and a closure that maps an incoming gradient
///   to the result of `_Raw.choleskyGrad` evaluated with that same decomposition.
@inlinable
internal func _vjpCholesky<T: TensorFlowFloatingPoint>(
    _ x: Tensor<T>
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
    // Computed once and captured by the pullback so the backward pass reuses the factor.
    let factor = cholesky(x)
    let pullback: (Tensor<T>) -> Tensor<T> = { upstream in
        return _Raw.choleskyGrad(l: factor, grad: upstream)
    }
    return (factor, pullback)
}

public extension Tensor where Scalar: TensorFlowFloatingPoint {
/// Returns the QR decomposition of each inner matrix in the tensor, a tensor with inner
/// orthogonal matrices `q` and a tensor with inner upper triangular matrices `r`, such that the
Expand Down
31 changes: 31 additions & 0 deletions Tests/TensorFlowTests/OperatorTests/MathTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,36 @@ final class MathOperatorTests: XCTestCase {
XCTAssertEqual(Double(prediction.scalars[0]), 0.816997, accuracy: 0.0001)
}

/// Checks `cholesky` in two ways: the factor `l` must reconstruct its input via `l * lᵀ` for
/// several (batched) shapes, and the gradient of `cholesky(x).sum()` must match reference
/// values.
func testCholesky() {
let shapes = [[3, 3], [4, 2, 2], [2, 1, 16, 16]]
let permutations = [[1, 0], [0, 2, 1], [0, 1, 3, 2]] // To avoid permuting batch dimensions.
for (shape, permutation) in zip(shapes, permutations) {
let a = Tensor<Float>(randomNormal: TensorShape(shape))
let x = matmul(a, a.transposed(permutation: permutation)) // `x = a * aᵀ` is symmetric and (almost surely, for random `a`) positive-definite.
let l = cholesky(x)
let xReconstructed = matmul(l, l.transposed(permutation: permutation))
assertEqual(xReconstructed, x, accuracy: 1e-5)
}

// The expected value of the gradient was computed using the following Python code:
// ```
// import tensorflow as tf
// x = tf.constant([[[6., 4.], [4., 6.]], [[2., 6.], [6., 20.]]])
// with tf.GradientTape() as tape:
// tape.watch(x)
// l = tf.reduce_sum(tf.linalg.cholesky(x))
// print(tape.gradient(l, x))
// ```
let x = Tensor<Float>([[[6, 4], [4, 6]], [[2, 6], [6, 20]]])
let computedGradient = gradient(at: x) { cholesky($0).sum() }
let expectedGradient = Tensor<Float>([
[[0.1897575, 0.02154995],
[0.02154995, 0.2738613]],
[[2.4748755, -0.7071073],
[-0.7071073, 0.3535535]]])
assertEqual(computedGradient, expectedGradient, accuracy: 1e-5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The accuracy level: is that a rtol or atol?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in any case it is too high

}

func testQRDecompositionApproximation() {
let shapes = [[5, 8], [3, 4, 4], [3, 3, 32, 64]]
for shape in shapes {
Expand Down Expand Up @@ -570,6 +600,7 @@ final class MathOperatorTests: XCTestCase {
("testXWPlusB", testXWPlusB),
("testXORInference", testXORInference),
("testMLPClassifierStruct", testMLPClassifierStruct),
("testCholesky", testCholesky),
("testQRDecompositionApproximation", testQRDecompositionApproximation),
("testDiagonalPart", testDiagonalPart),
("testBroadcastedAddGradient", testBroadcastedAddGradient)
Expand Down