
Commit 484148f

Shashi456 authored and rxwei committed
Add leaky ReLU activation function (#260)
1 parent 9a59c89 commit 484148f

File tree

2 files changed: +31 -1 lines changed


Sources/TensorFlow/Operators/Math.swift

Lines changed: 23 additions & 1 deletion
@@ -997,7 +997,7 @@ func _vjpLogSoftmax<T: TensorFlowFloatingPoint>(
 }
 
 /// Returns a tensor by applying an exponential linear unit.
-/// Specifically, computes `exp(features) - 1` if < 0, `features` otherwise.
+/// Specifically, computes `exp(x) - 1` if < 0, `x` otherwise.
 /// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
 /// ](http://arxiv.org/abs/1511.07289)
 @inlinable
@@ -1014,6 +1014,28 @@ func _vjpElu<T: TensorFlowFloatingPoint>(
     return (y, { v in Raw.eluGrad(gradients: v, outputs: y) })
 }
 
+/// Returns a tensor by applying the leaky ReLU activation function
+/// to the specified tensor element-wise.
+/// Specifically, computes `max(x, x * alpha)`.
+@inlinable
+@differentiable(wrt: x, vjp: _vjpLeakyRelu)
+public func leakyRelu<T: TensorFlowFloatingPoint>(
+    _ x: Tensor<T>,
+    alpha: Double = 0.2
+) -> Tensor<T> {
+    Raw.leakyRelu(features: x, alpha: alpha)
+}
+
+@inlinable
+func _vjpLeakyRelu<T: TensorFlowFloatingPoint>(
+    _ x: Tensor<T>,
+    alpha: Double
+) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
+    return (leakyRelu(x, alpha: alpha), { v in
+        Raw.leakyReluGrad(gradients: v, features: x, alpha: alpha)
+    })
+}
+
 /// Computes `relu` of the specified tensor element-wise.
 /// Specifically, computes `max(0, x)`.
 @inlinable
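The custom VJP registered through `@differentiable(wrt: x, vjp: _vjpLeakyRelu)` makes the new function usable under Swift's automatic differentiation: the pullback scales incoming gradients by 1 where `x > 0` and by `alpha` elsewhere, delegating to the `Raw.leakyReluGrad` kernel. A minimal usage sketch, assuming the `TensorFlow` module built from this commit (`gradient(at:in:)` and `sum()` are standard swift-apis utilities of this era; the commented values are illustrative):

import TensorFlow

let x = Tensor<Float>([-2.0, -0.5, 3.0])
// Forward pass: positive entries pass through, negative ones are scaled by alpha.
let y = leakyRelu(x, alpha: 0.2)
// y == [-0.4, -0.1, 3.0]

// Differentiating through leakyRelu exercises _vjpLeakyRelu.
let grad = gradient(at: x) { leakyRelu($0, alpha: 0.2).sum() }
// grad == [0.2, 0.2, 1.0]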

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 8 additions & 0 deletions
@@ -243,6 +243,13 @@ final class MathOperatorTests: XCTestCase {
         XCTAssertEqual(y, expected)
     }
 
+    func testLeakyRelu() {
+        let x = Tensor<Float>([[-1.0, 2.0, 3.0]])
+        let y = leakyRelu(x, alpha: 0.4)
+        let expected = Tensor<Float>([-0.4, 2, 3])
+        XCTAssertEqual(y, expected)
+    }
+
     func testXORInference() {
         func xor(_ x: Float, _ y: Float) -> Float {
             let x = Tensor<Float>([x, y]).reshaped(to: [1, 2])
@@ -314,6 +321,7 @@ final class MathOperatorTests: XCTestCase {
         ("testArgmax", testArgmax),
         ("testSoftplus", testSoftplus),
         ("testSoftsign", testSoftsign),
+        ("testLeakyRelu", testLeakyRelu),
         ("testCeilAndFloor", testCeilAndFloor),
         ("testSimpleMath", testSimpleMath),
         ("testStandardDeviation", testStandardDeviation),
