Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 190cf87

Browse files
Shashi456dan-zheng
authored andcommitted
[Linear Algebra] Add det and slogdet (#604)
1 parent 1ed2bc7 commit 190cf87

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

Sources/TensorFlow/Operators/LinearAlgebra.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,31 @@ public func trace<T: TensorFlowNumeric>(_ matrix: Tensor<T>) -> Tensor<T> {
177177
return matrix.diagonalPart().sum(squeezingAxes: -1)
178178
}
179179

180+
/// Computes the determinant of an optionally batched matrix.
181+
///
182+
/// - Parameter matrix: A tensor of shape `[..., M, M]`.
183+
/// - Returns: A tensor containing the determinants of all input submatrices.
184+
@inlinable
185+
func det<T: TensorFlowFloatingPoint>(_ matrix: Tensor<T>) -> Tensor<T> {
186+
_Raw.matrixDeterminant(matrix)
187+
}
188+
189+
/// Computes the sign and the natural logarithm of the absolute value of the determinant of an
190+
/// optionally batched square matrix.
191+
///
192+
/// - Parameter matrix: A tensor of shape `[..., N, M, M]`.
193+
/// - Returns:
194+
/// - sign: A tensor with shape `[N]`, representing the signs of the natural logarithms of the
195+
/// determinants of input submatrices.
196+
/// - logAbsDeterminant: A tensor with shape `[N]`, representing the natural logarithms of the
197+
/// absolute values of the determinants of input submatrices.
198+
@inlinable
199+
func slogdet<T: TensorFlowFloatingPoint>(_ matrix: Tensor<T>) -> (
200+
sign: Tensor<T>, logAbsDeterminant: Tensor<T>
201+
) {
202+
_Raw.logMatrixDeterminant(matrix)
203+
}
204+
180205
/// Computes the natural logarithm of the determinant of a hermitian positive definite matrix.
181206
///
182207
/// - Parameter matrix: A tensor of shape `[..., M, N]`.

Tests/TensorFlowTests/OperatorTests/LinearAlgebraTests.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,34 @@ final class LinearAlgebraTests: XCTestCase {
115115
assertEqual(computedGradient, expectedGradient, accuracy: 1e-16)
116116
}
117117

118+
func testDet() {
119+
var matrix = Tensor<Float>(shape: [1, 4, 4], scalars: (0..<16).map(Float.init))
120+
var computedDet = det(matrix)
121+
var expectedDet = Tensor<Float>([0])
122+
XCTAssertEqual(computedDet, expectedDet)
123+
124+
matrix = Tensor<Float>(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
125+
computedDet = det(matrix)
126+
expectedDet = Tensor<Float>([[-2.0, -2.0], [-2.0, -2.0]])
127+
assertEqual(computedDet, expectedDet, accuracy: 1e-5)
128+
}
129+
130+
func testSlogdet() {
131+
var input = Tensor<Float>(shape: [1, 2, 2], scalars: (0..<4).map(Float.init))
132+
var expectedSigns = Tensor<Float>([-1])
133+
var expectedLogs = Tensor<Float>([0.6931472])
134+
var (computedSigns, computedLogs) = slogdet(input)
135+
XCTAssertEqual(computedSigns, expectedSigns)
136+
XCTAssertEqual(computedLogs, expectedLogs)
137+
138+
input = Tensor<Float>(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
139+
expectedSigns = Tensor<Float>([[-1.0, -1.0], [-1.0, -1.0]])
140+
expectedLogs = Tensor<Float>([[0.6931472, 0.6931462], [0.6931462, 0.6931435]])
141+
(computedSigns, computedLogs) = slogdet(input)
142+
XCTAssertEqual(computedSigns, expectedSigns)
143+
XCTAssertEqual(computedLogs, expectedLogs)
144+
}
145+
118146
func testLogdet() {
119147
let input = Tensor<Float>([[[6.0, 4.0], [4.0, 6.0]], [[2.0, 6.0], [6.0, 20.0]]])
120148
let expected = Tensor<Float>([2.9957323, 1.3862934])
@@ -148,6 +176,8 @@ final class LinearAlgebraTests: XCTestCase {
148176
("testSVD", testSVD),
149177
("testTrace", testTrace),
150178
("testTraceGradient", testTraceGradient),
179+
("testDet", testDet),
180+
("testSlogdet", testSlogdet),
151181
("testLogdet", testLogdet),
152182
("testLogdetGradient", testLogdetGradient)
153183
]

Tests/TensorFlowTests/OperatorTests/MatrixTests.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import XCTest
1616
@testable import TensorFlow
1717

18-
1918
final class MatrixTests: XCTestCase {
2019
func testDiagonalPart() {
2120
// Test on a matrix.

0 commit comments

Comments
 (0)