Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit db42eb9

Browse files
jon-towsaeta
authored and committed
Add support for gelu (#268)
1 parent aea6a4c commit db42eb9

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,24 @@ func _vjpRelu<T: TensorFlowFloatingPoint>(
10511051
(relu(x), { v in Tensor(x .> 0) * v })
10521052
}
10531053

1054+
/// Computes the Gaussian Error Linear Unit (GELU) activation of each element of `x`.
///
/// The exact GELU is `x * P(X <= x)`, where `P(X <= x)` is the cumulative distribution function
/// of the standard Gaussian. This function evaluates the tanh-based approximation
/// `x * 0.5 * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))`.
///
/// See [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
///
/// - Parameter x: The input tensor.
/// - Returns: A tensor holding the approximate GELU of every element of `x`.
@inlinable
@differentiable
public func gelu<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
    // √(2/π), rounded to ten decimal places.
    let sqrtTwoOverPi = Tensor<T>(0.7978845608)
    // NOTE: The formula is deliberately built up from small sub-expressions. Written as a single
    // expression, the compiler reports an "unable to type-check this in reasonable time" error.
    let cubicTerm = 0.044715 * pow(x, 3)
    let tanhArgument = sqrtTwoOverPi * (x + cubicTerm)
    // The tanh term stands in for the error-function term of the Gaussian CDF.
    let tanhValue = tanh(tanhArgument)
    let gaussianCDF = 0.5 * (1.0 + tanhValue)
    return x * gaussianCDF
}
1071+
10541072
//===------------------------------------------------------------------------------------------===//
10551073
// Element-wise Binary Math Functions
10561074
//===------------------------------------------------------------------------------------------===//

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,13 @@ final class MathOperatorTests: XCTestCase {
243243
XCTAssertEqual(y, expected)
244244
}
245245

246+
/// Checks `gelu` against hard-coded reference values for a few positive inputs.
func testGelu() {
    let x = Tensor<Float>([2.0, 1.0, 7.0])
    let y = gelu(x)
    let expected = Tensor<Float>([1.95459769, 0.84119199, 7.0])
    // `gelu` is a floating-point approximation computed through `pow` and `tanh`, so the last
    // bits of the result may differ between backends. Compare the shape exactly and the values
    // with a tolerance rather than requiring bitwise-equal tensors.
    XCTAssertEqual(y.shape, expected.shape)
    for (actual, reference) in zip(y.scalars, expected.scalars) {
        XCTAssertEqual(actual, reference, accuracy: 1e-5)
    }
}
252+
246253
func testLeakyRelu() {
247254
let x = Tensor<Float>([[-1.0, 2.0, 3.0]])
248255
let y = leakyRelu(x, alpha: 0.4)
@@ -318,6 +325,7 @@ final class MathOperatorTests: XCTestCase {
318325
("testReduction", testReduction),
319326
("testCosineSimilarity", testCosineSimilarity),
320327
("testElu",testElu),
328+
("testGelu", testGelu),
321329
("testArgmax", testArgmax),
322330
("testSoftplus", testSoftplus),
323331
("testSoftsign", testSoftsign),

0 commit comments

Comments
 (0)