Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit eebb487

Browse files
awavsaeta
authored andcommitted
[Linear Algebra] Trace operator (#586)
Adds a trace operator and adds tests and derivative tests.
1 parent b878b67 commit eebb487

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

Sources/TensorFlow/Operators/LinearAlgebra.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,21 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
106106
}
107107
}
108108

109+
/// Computes the trace of an optionally batched matrix.
110+
/// The trace is the the sum along the main diagonal of each inner-most matrix.
111+
///
112+
/// The input is a tensor with shape `[..., M, N]`.
113+
/// The output is a tensor with shape `[...]`.
114+
///
115+
/// - Parameter matrix: A tensor of shape `[..., M, N]`.
116+
/// - Precondition: `matrix` must be a tensor with shape `[..., M, N]`.
117+
@inlinable
118+
@differentiable(wrt: matrix where T: TensorFlowFloatingPoint)
119+
public func trace<T: TensorFlowNumeric>(_ matrix: Tensor<T>) -> Tensor<T> {
120+
precondition(matrix.rank >= 2, "The tensor must have at least rank 2.")
121+
return matrix.diagonalPart().sum(squeezingAxes: -1)
122+
}
123+
109124
// MARK: - Decompositions
110125

111126
/// Returns the Cholesky decomposition of one or more square matrices.

Tests/TensorFlowTests/OperatorTests/LinearAlgebraTests.swift

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,35 @@ final class LinearAlgebraTests: XCTestCase {
6060
}
6161
}
6262

63+
func testTrace() {
64+
assertEqual(trace(Tensor<Float>(ones: [3, 3])), Tensor(3.0), accuracy: 1e-16)
65+
assertEqual(trace(Tensor<Float>(ones: [5, 6])), Tensor(5.0), accuracy: 1e-16)
66+
let shapes = [[1, 3, 3], [2, 4, 4], [2, 3, 5, 5]]
67+
for shape in shapes {
68+
let x = Tensor<Float>(ones: TensorShape(shape))
69+
let computedTrace = trace(x)
70+
let leadingShape = TensorShape(shape.dropLast(2))
71+
let value = Float(shape.last!)
72+
let expectedTrace = Tensor<Float>(repeating: value, shape: leadingShape)
73+
assertEqual(computedTrace, expectedTrace, accuracy: 1e-16)
74+
}
75+
}
76+
77+
func testTraceGradient() {
78+
let shape: TensorShape = [2, 4, 4]
79+
let scalars = (0..<shape.contiguousSize).map(Float.init)
80+
let x = Tensor<Float>(shape: shape, scalars: scalars)
81+
let computedGradient = gradient(at: x) { (trace($0) * [2.0, 3.0]).sum() }
82+
let a = Tensor<Float>(repeating: 2.0, shape: [4]).diagonal()
83+
let b = Tensor<Float>(repeating: 3.0, shape: [4]).diagonal()
84+
let expectedGradient = Tensor([a, b])
85+
assertEqual(computedGradient, expectedGradient, accuracy: 1e-16)
86+
}
87+
6388
static var allTests = [
6489
("testCholesky", testCholesky),
65-
("testQRDecompositionApproximation", testQRDecompositionApproximation)
90+
("testQRDecompositionApproximation", testQRDecompositionApproximation),
91+
("testTrace", testTrace),
92+
("testTraceGradient", testTraceGradient),
6693
]
6794
}

0 commit comments

Comments
 (0)