This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ef3af7a

eaplatanios and rxwei authored
Added support for the 'log1mexp' op and its VJP. (#147)
* Enhanced the 'matmul' wrapper so that it matches the behavior of the Python one.
* Added support for the 'log1mexp' op and its VJP.
* Added a test.
* Update Sources/TensorFlow/Operators/Math.swift
  Co-Authored-By: Richard Wei <[email protected]>
* Removed the need for a general 'Tensor.withoutDerivative()' as Richard suggested.
* Addressed Richard's feedback.
* Addressed Richard's feedback.
* Added one more test helper.
* Minor bug fix.
* Added a test for 'log1mexp'.
* Added support for 'softplus' and 'logSigmoid'.
* Minor tweak.
* Fixed some of the tests.
* Made the tests pass.
* Attempt at making 'log1mexp' differentiable.
* Merged upstream changes.
* Enabled the 'logSigmoid' test.
* Style edits.
* Style edits.
* Update Sources/TensorFlow/Operators/Math.swift
  Co-Authored-By: Richard Wei <[email protected]>
* Update Sources/TensorFlow/Operators/Math.swift
  Co-Authored-By: Richard Wei <[email protected]>
1 parent 1138b08 commit ef3af7a

2 files changed: +33 additions, -10 deletions

Sources/TensorFlow/Operators/Math.swift

Lines changed: 16 additions & 1 deletion
@@ -579,6 +579,21 @@ func _vjpLog1p<T: TensorFlowFloatingPoint>(
     (log1p(x), { v in Raw.xdivy(v, 1 + x) })
 }
 
+/// Returns `log(1 - exp(x))` using a numerically stable approach.
+///
+/// - Note: The approach is shown in Equation 7 of:
+///   https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf.
+@inlinable
+@differentiable
+public func log1mexp<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+    let isTooSmall = withoutDerivative(at: x) { x in -x .< T(log(2.0)) }
+    // This `replacing` is ultimately a no-op: the entries where the surrogate
+    // `-Tensor(onesLike: x)` is used all take the `log(-expm1(x))` branch below.
+    let ones = withoutDerivative(at: x) { x in Tensor(onesLike: x) }
+    let xSafe = x.replacing(with: -ones, where: isTooSmall)
+    return log1p(-exp(xSafe)).replacing(with: log(-expm1(x)), where: isTooSmall)
+}
+
 /// Returns the sine of the specified tensor element-wise.
 @inlinable
 @differentiable(vjp: _vjpSin(_:))
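The branch condition mirrors Equation 7 of the Mächler note cited in the doc comment: for x < 0, when exp(x) exceeds 1/2 (x above -log 2) the subtraction 1 - exp(x) cancels catastrophically, so log(-expm1(x)) is used; otherwise exp(x) is small and log1p(-exp(x)) is accurate. A minimal scalar sketch of the same branch logic, assuming only Foundation's exp/log/expm1/log1p (an illustration, not part of this commit):

import Foundation

// Scalar reference for the tensor op above. Both branches compute
// log(1 - exp(x)) for x < 0; each is used where it is numerically accurate:
//   -x <  log(2)  =>  exp(x) > 1/2, so 1 - exp(x) cancels; use log(-expm1(x)).
//   -x >= log(2)  =>  exp(x) <= 1/2, so log1p(-exp(x)) is accurate.
func log1mexpScalar(_ x: Double) -> Double {
    precondition(x < 0, "log(1 - exp(x)) requires x < 0")
    return -x < log(2.0) ? log(-expm1(x)) : log1p(-exp(x))
}

// For example, with x = -1e-20, exp(x) rounds to 1.0 in Double, so the naive
// log(1 - exp(x)) returns -inf, while log(-expm1(x)) returns roughly -46.05.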
@@ -912,7 +927,7 @@ internal func _vjpSigmoid<T: TensorFlowFloatingPoint>(
 }
 
 /// Returns the log-sigmoid of the specified tensor element-wise. Specifically,
-/// `y = log(1 / (1 + exp(-x)))`. For numerical stability, we use `y = -softplus(-x)`.
+/// `log(1 / (1 + exp(-x)))`. For numerical stability, we use `-softplus(-x)`.
 @inlinable
 @differentiable
 public func logSigmoid<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
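The identity behind that doc comment is log(1 / (1 + exp(-x))) = -log(1 + exp(-x)) = -softplus(-x). A scalar sketch of a stable softplus and the resulting log-sigmoid (illustrative names, not the library source):

import Foundation

// Numerically stable softplus: log(1 + exp(x)) overflows exp for large x, but
// the algebraically equal form max(x, 0) + log1p(exp(-|x|)) never does.
func softplusScalar(_ x: Double) -> Double {
    return max(x, 0) + log1p(exp(-abs(x)))
}

// Log-sigmoid via the identity in the doc comment:
// log(1 / (1 + exp(-x))) = -log(1 + exp(-x)) = -softplus(-x).
func logSigmoidScalar(_ x: Double) -> Double {
    return -softplusScalar(-x)
}

// For example, at x = -800 the naive form underflows to log(0) = -inf, while
// logSigmoidScalar(-800) returns -800 (log-sigmoid is ~x for very negative x).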

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 17 additions & 9 deletions
@@ -66,6 +66,13 @@ final class MathOperatorTests: XCTestCase {
         assertEqual(y, expectedY, accuracy: 0.0001)
     }
 
+    func testLog1mexp() {
+        let x = Tensor<Float>([-1, -2, -3, -4, -5])
+        let y = log1mexp(x)
+        let expectedY = Tensor<Float>([-0.45868, -0.14541, -0.05107, -0.01849, -0.00676])
+        assertEqual(y, expectedY, accuracy: 0.0001)
+    }
+
     func testExpm1() {
         let x = Tensor<Float>([1, 2, 3, 4, 5])
         let y = expm1(x)
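Since the op is marked '@differentiable', a derivative check would be a natural companion to the value test above. A hypothetical sketch in the style of the surrounding test class ('testLog1mexpGradient' is not part of this commit):

// Hypothetical companion test: checks the VJP wired up by '@differentiable'
// against the analytic derivative d/dx log(1 - exp(x)) = -exp(x) / (1 - exp(x)).
func testLog1mexpGradient() {
    let x = Tensor<Float>([-1, -2, -3, -4, -5])
    let dydx = gradient(at: x) { log1mexp($0).sum() }
    let expected = -exp(x) / (1 - exp(x))
    assertEqual(dydx, expected, accuracy: 0.0001)
}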
@@ -350,19 +357,20 @@ final class MathOperatorTests: XCTestCase {
     }
 
     func testBroadcastedAddGradient() {
-      func foo(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
-          return (x + y).sum()
-      }
-      let x = Tensor<Float>(ones: [1, 2, 1, 4])
-      let y = Tensor<Float>(ones: [4, 1, 3, 1])
-      let (dx, dy) = gradient(at: x, y, in: foo)
-      XCTAssertEqual(x.shape, dx.shape)
-      XCTAssertEqual(y.shape, dy.shape)
-      }
+        func foo(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
+            return (x + y).sum()
+        }
+        let x = Tensor<Float>(ones: [1, 2, 1, 4])
+        let y = Tensor<Float>(ones: [4, 1, 3, 1])
+        let (dx, dy) = gradient(at: x, y, in: foo)
+        XCTAssertEqual(x.shape, dx.shape)
+        XCTAssertEqual(y.shape, dy.shape)
+    }
 
     static var allTests = [
         ("testElementaryFunctions", testElementaryFunctions),
         ("testLog1p", testLog1p),
+        ("testLog1mexp", testLog1mexp),
         ("testExpm1", testExpm1),
         ("testSign", testSign),
         ("testLogSigmoid", testLogSigmoid),
