Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 8e4b552

Browse files
eaplataniossaeta
authored andcommitted
Bug fix for Tensor.logSumExp. (#551)
Fixed a bug in the 'logsumexp' implementation and augment the tests to ensure correctness.
1 parent ee72ba1 commit 8e4b552

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2371,8 +2371,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
23712371
func logSumExp(squeezingAxes axes: Tensor<Int32>) -> Tensor {
23722372
let rawMax = max(alongAxes: axes)
23732373
let offset = withoutDerivative(at: rawMax) { rawMax in
2374-
rawMax.replacing(
2375-
with: Tensor<Scalar>(zerosLike: rawMax),
2374+
Tensor<Scalar>(zerosLike: rawMax).replacing(
2375+
with: rawMax,
23762376
where: rawMax.isFinite)
23772377
}
23782378
let result = TensorFlow.log(TensorFlow.exp(self - offset).sum(squeezingAxes: axes))
@@ -2435,8 +2435,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
24352435
func logSumExp(alongAxes axes: Tensor<Int32>) -> Tensor {
24362436
let rawMax = max(alongAxes: axes)
24372437
let offset = withoutDerivative(at: rawMax) { rawMax in
2438-
rawMax.replacing(
2439-
with: Tensor<Scalar>(zerosLike: rawMax),
2438+
Tensor<Scalar>(zerosLike: rawMax).replacing(
2439+
with: rawMax,
24402440
where: rawMax.isFinite)
24412441
}
24422442
let result = TensorFlow.log(TensorFlow.exp(self - offset).sum(alongAxes: axes))

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,12 @@ final class MathOperatorTests: XCTestCase {
340340
assertEqual(y0, expectedY0, accuracy: 0.0001)
341341
assertEqual(y1, expectedY1, accuracy: 0.0001)
342342
assertEqual(y2, expectedY2, accuracy: 0.0001)
343+
344+
let xSmall = Tensor<Float>([
345+
-301.9475, -265.2244, -275.77475, -235.28029, -277.2509, -396.6921, -400.01385])
346+
let ySmall = xSmall.logSumExp()
347+
let expectedYSmall = Tensor<Float>(-235.28029)
348+
assertEqual(ySmall, expectedYSmall, accuracy: 0.0001)
343349
}
344350

345351
func testMoments() {

0 commit comments

Comments
 (0)