Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit f2acd6c

Browse files
authored
Add the swish function (#646)
1 parent 94ab4f5 commit f2acd6c

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,33 @@ func _vjpSelu<T: TensorFlowFloatingPoint>(
13301330
})
13311331
}
13321332

1333+
/// Returns a tensor by applying the swish activation function, namely
1334+
/// `x * sigmoid(x)`.
1335+
///
1336+
/// Source: "Searching for Activation Functions" (Ramachandran et al. 2017)
1337+
/// https://arxiv.org/abs/1710.05941
1338+
@inlinable
1339+
@differentiable
1340+
public func swish<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
1341+
x * sigmoid(x)
1342+
}
1343+
1344+
// Note: A custom vjp function for swish is required to avoid excessive
1345+
// tensor memory consumption due to storing both `x` and `sigmoid(x)` for
1346+
// backprop. This vjp recomputes `sigmoid(x)` during backprop, so that
1347+
// the `sigmoid(x)` expression can be freed during the forward pass.
1348+
@inlinable
1349+
@derivative(of: swish)
1350+
func _vjpSwish<T: TensorFlowFloatingPoint>(
1351+
_ x: Tensor<T>
1352+
) -> (value: Tensor<T>, pullback: (Tensor<T>) -> Tensor<T>) {
1353+
return (swish(x), { v in
1354+
let sigmoidFeatures = sigmoid(x)
1355+
let grad = sigmoidFeatures * (1.0 + x * (1 - sigmoidFeatures))
1356+
return grad * v
1357+
})
1358+
}
1359+
13331360
public extension Tensor where Scalar: TensorFlowFloatingPoint {
13341361
/// Returns a boolean tensor indicating which elements of `x` are finite.
13351362
@inlinable var isFinite: Tensor<Bool> { _Raw.isFinite(self) }

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,13 @@ final class MathOperatorTests: XCTestCase {
250250
assertEqual(y, expectedY, accuracy: 1e-5)
251251
}
252252

253+
func testSwish() {
254+
let x = Tensor<Float>([[-1.0, 2.0, 3.0]])
255+
let y = swish(x)
256+
let expectedY = Tensor<Float>([-0.26894143, 1.761594, 2.8577223])
257+
assertEqual(y, expectedY, accuracy: 1e-5)
258+
}
259+
253260
func testIsFinite() {
254261
let x = Tensor<Float>([1, 2, 3, 4, -Float.infinity])
255262
let y = x.isFinite
@@ -606,6 +613,7 @@ final class MathOperatorTests: XCTestCase {
606613
("testRelu6", testRelu6),
607614
("testLeakyRelu", testLeakyRelu),
608615
("testSelu", testSelu),
616+
("testSwish", testSwish),
609617
("testIsFinite", testIsFinite),
610618
("testIsInfinite", testIsInfinite),
611619
("testIsNaN", testIsNaN),

0 commit comments

Comments
 (0)