Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit daff615

Browse files
jon-towsaeta
authored andcommitted
Add support for ReLU6 activation (#435)
1 parent fe239c5 commit daff615

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,20 @@ func _vjpRelu<T: TensorFlowFloatingPoint>(
11831183
(relu(x), { v in Raw.reluGrad(gradients: v, features: x) })
11841184
}
11851185

1186+
/// Returns a tensor by applying the ReLU6 activation function, namely `min(max(0, x), 6)`.
1187+
@inlinable
1188+
@differentiable(vjp: _vjpRelu6(_:))
1189+
public func relu6<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
1190+
Raw.relu6(features: x)
1191+
}
1192+
1193+
@inlinable
1194+
func _vjpRelu6<T: TensorFlowFloatingPoint>(
1195+
_ x: Tensor<T>
1196+
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
1197+
(relu6(x), { v in Raw.relu6Grad(gradients: v, features: x)})
1198+
}
1199+
11861200
/// Returns a tensor by applying the leaky ReLU activation function
11871201
/// to the specified tensor element-wise.
11881202
/// Specifically, computes `max(x, x * alpha)`.

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ final class MathOperatorTests: XCTestCase {
153153
XCTAssertEqual(y, expectedY)
154154
}
155155

156+
func testRelu6() {
157+
let x = Tensor<Float>([1.0, -2.0, 3.0, 4.0, 10.0])
158+
let y = relu6(x)
159+
let expectedY = Tensor<Float>([1.0, 0, 3.0, 4.0, 6.0])
160+
XCTAssertEqual(y, expectedY)
161+
}
162+
156163
func testLeakyRelu() {
157164
let x = Tensor<Float>([[-1.0, 2.0, 3.0]])
158165
let y = leakyRelu(x, alpha: 0.4)
@@ -536,6 +543,7 @@ final class MathOperatorTests: XCTestCase {
536543
("testElu",testElu),
537544
("testGelu", testGelu),
538545
("testRelu", testRelu),
546+
("testRelu6", testRelu6),
539547
("testLeakyRelu", testLeakyRelu),
540548
("testSelu", testSelu),
541549
("testIsFinite", testIsFinite),

0 commit comments

Comments
 (0)