Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c77b507

Browse files
committed
add mish activation function
1 parent 897bac9 commit c77b507

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,17 @@ public func hardSwish<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
15181518
x * hardSigmoid(x)
15191519
}
15201520

1521+
/// Returns a tensor by applying the mish activation function, namely
1522+
/// `x * tanh(softplus(x))`.
1523+
///
1524+
/// Source: "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
1525+
/// https://arxiv.org/abs/1908.08681
1526+
@inlinable
1527+
@differentiable
1528+
public func mish<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
1529+
x * tanh(softplus(x))
1530+
}
1531+
15211532
extension Tensor where Scalar: TensorFlowFloatingPoint {
15221533
/// Returns a boolean tensor indicating which elements of `x` are finite.
15231534
@inlinable public var isFinite: Tensor<Bool> { _Raw.isFinite(self) }

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,13 @@ final class MathOperatorTests: XCTestCase {
284284
let expectedY = Tensor<Float>([0.0, -0.33333334, 0.0, 1.6666666, 4.0])
285285
assertEqual(y, expectedY, accuracy: 1e-5)
286286
}
287+
288+
func testMish() {
289+
let x = Tensor<Float>([-4, -2, 0, 2, 4])
290+
let y = mish(x)
291+
let expectedY = Tensor<Float>([-0.07259174, -0.25250146, 0.0, 1.943959, 3.9974122])
292+
assertEqual(y, expectedY, accuracy: 1e-5)
293+
}
287294

288295
func testIsFinite() {
289296
let x = Tensor<Float>([1, 2, 3, 4, -Float.infinity])
@@ -651,6 +658,7 @@ final class MathOperatorTests: XCTestCase {
651658
("testSwish", testSwish),
652659
("testHardSigmoid", testHardSigmoid),
653660
("testHardSwish", testHardSwish),
661+
("testMish", testMish),
654662
("testIsFinite", testIsFinite),
655663
("testIsInfinite", testIsInfinite),
656664
("testIsNaN", testIsNaN),

0 commit comments

Comments
 (0)