Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 24fc2ba

Browse files
eaplataniosrxwei
authored andcommitted
Added support for the 'log1p' op and its VJP. (#145)
1 parent 349ae47 commit 24fc2ba

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,20 @@ internal func _vjpLog<T: TensorFlowFloatingPoint>(
408408
return (log(x), { v in v / x })
409409
}
410410

411+
/// Computes the logarithm of `1 + x` element-wise.
412+
@inlinable
413+
@differentiable(vjp: _vjpLog1p)
414+
public func log1p<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
415+
Raw.log1p(x)
416+
}
417+
418+
@inlinable
419+
func _vjpLog1p<T: TensorFlowFloatingPoint>(
420+
_ x: Tensor<T>
421+
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
422+
(log1p(x), { v in Raw.xdivy(v, 1 + x) })
423+
}
424+
411425
/// Computes `sin` of the specified tensor element-wise.
412426
@inlinable
413427
@differentiable(vjp: _vjpSin(_:) where T: TensorFlowFloatingPoint)

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ import XCTest
1616
@testable import TensorFlow
1717

1818
final class MathOperatorTests: XCTestCase {
19+
func testLog1p() {
20+
let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
21+
let y = log1p(x)
22+
assertEqual(y, log(1 + x), accuracy: 0.0001)
23+
}
24+
1925
func testSign() {
2026
let x = Tensor<Float>([[1, 2, -3, 4, 5], [1, 2, 3, 4, -5]])
2127
let y = sign(x)
@@ -217,6 +223,7 @@ final class MathOperatorTests: XCTestCase {
217223
}
218224

219225
static var allTests = [
226+
("testLog1p", testLog1p),
220227
("testSign", testSign),
221228
("testReduction", testReduction),
222229
("testArgmax", testArgmax),

0 commit comments

Comments
 (0)