Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 9a59c89

Browse files
Shashi456rxwei
authored andcommitted
Add exponential linear unit (#252)
1 parent 760fafc commit 9a59c89

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,24 @@ func _vjpLogSoftmax<T: TensorFlowFloatingPoint>(
996996
return (value, { v in v - v.sum(alongAxes: -1) * exp(value) })
997997
}
998998

999+
/// Returns a tensor by applying an exponential linear unit.
1000+
/// Specifically, computes `exp(features) - 1` if < 0, `features` otherwise.
1001+
/// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
1002+
/// ](http://arxiv.org/abs/1511.07289)
1003+
@inlinable
1004+
@differentiable(vjp: _vjpElu)
1005+
public func elu<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
1006+
Raw.elu(features: x)
1007+
}
1008+
1009+
@inlinable
1010+
func _vjpElu<T: TensorFlowFloatingPoint>(
1011+
_ x: Tensor<T>
1012+
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
1013+
let y = elu(x)
1014+
return (y, { v in Raw.eluGrad(gradients: v, outputs: y) })
1015+
}
1016+
9991017
/// Computes `relu` of the specified tensor element-wise.
10001018
/// Specifically, computes `max(0, x)`.
10011019
@inlinable

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,17 +223,24 @@ final class MathOperatorTests: XCTestCase {
223223
}
224224

225225
func testSoftplus() {
226-
let x = Tensor<Float>([1.0, 2.0, 3.0])
227-
let y = softplus(x)
228-
let expected = Tensor<Float>([1.3132616, 2.126928, 3.0485873])
229-
XCTAssertEqual(y, expected)
226+
let x = Tensor<Float>([1.0, 2.0, 3.0])
227+
let y = softplus(x)
228+
let expected = Tensor<Float>([1.3132616, 2.126928, 3.0485873])
229+
XCTAssertEqual(y, expected)
230230
}
231231

232232
func testSoftsign() {
233-
let x = Tensor<Float>([1.0, 4.0, 3.0])
234-
let y = softsign(x)
235-
let expected = Tensor<Float>([0.5 , 0.8 , 0.75])
236-
XCTAssertEqual(y, expected)
233+
let x = Tensor<Float>([1.0, 4.0, 3.0])
234+
let y = softsign(x)
235+
let expected = Tensor<Float>([0.5 , 0.8 , 0.75])
236+
XCTAssertEqual(y, expected)
237+
}
238+
239+
func testElu() {
240+
let x = Tensor<Float>([-1.0, 2.0, 3.0])
241+
let y = elu(x)
242+
let expected = Tensor<Float>([-0.63212055, 2, 3])
243+
XCTAssertEqual(y, expected)
237244
}
238245

239246
func testXORInference() {
@@ -303,6 +310,7 @@ final class MathOperatorTests: XCTestCase {
303310
("testSign", testSign),
304311
("testReduction", testReduction),
305312
("testCosineSimilarity", testCosineSimilarity),
313+
("testElu",testElu),
306314
("testArgmax", testArgmax),
307315
("testSoftplus", testSoftplus),
308316
("testSoftsign", testSoftsign),

0 commit comments

Comments
 (0)