Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Added support for the 'expm1' op and its VJP. #146

Merged
merged 11 commits into from
May 31, 2019
15 changes: 15 additions & 0 deletions Sources/TensorFlow/Operators/Math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,21 @@ internal func _vjpExp<T: TensorFlowFloatingPoint>(
return (value, { v in value * v })
}

/// Computes the exponential of `x - 1` element-wise.
@inlinable
@differentiable(vjp: _vjpExpm1)
public func expm1<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
Raw.expm1(x)
}

@inlinable
internal func _vjpExpm1<T: TensorFlowFloatingPoint>(
_ x: Tensor<T>
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
let y = expm1(x)
return (y, { v in v * y })
}

/// Returns the values of the specified tensor rounded to the nearest integer, element-wise.
@inlinable
@differentiable(vjp: _vjpRound)
Expand Down
7 changes: 7 additions & 0 deletions Tests/TensorFlowTests/OperatorTests/MathTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ final class MathOperatorTests: XCTestCase {
assertEqual(y, log(1 + x), accuracy: 0.0001)
}

func testExpm1() {
let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
let y = expm1(x)
assertEqual(y, exp(x - 1), accuracy: 0.0001)
}

func testSign() {
let x = Tensor<Float>([[1, 2, -3, 4, 5], [1, 2, 3, 4, -5]])
let y = sign(x)
Expand Down Expand Up @@ -224,6 +230,7 @@ final class MathOperatorTests: XCTestCase {

static var allTests = [
("testLog1p", testLog1p),
("testExpm1", testExpm1),
("testSign", testSign),
("testReduction", testReduction),
("testArgmax", testArgmax),
Expand Down