This repository was archived by the owner on Jul 1, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +26
-0
lines changed
Sources/TensorFlow/Operators
Tests/TensorFlowTests/OperatorTests Expand file tree Collapse file tree 2 files changed +26
-0
lines changed Original file line number Diff line number Diff line change @@ -1051,6 +1051,24 @@ func _vjpRelu<T: TensorFlowFloatingPoint>(
1051
1051
( relu ( x) , { v in Tensor ( x .> 0 ) * v } )
1052
1052
}
1053
1053
1054
/// Returns the Gaussian Error Linear Unit (GELU) activations of the specified tensor element-wise.
///
/// Specifically, `gelu` approximates `xP(X <= x)`, where `P(X <= x)` is the Standard Gaussian
/// cumulative distribution, by computing: x * [0.5 * (1 + tanh[√(2/π) * (x + 0.044715 * x^3)])].
///
/// See [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
@inlinable
@differentiable
public func gelu<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
    // √(2/π), the scaling constant used inside the tanh-based approximation.
    let sqrtTwoOverPi = Tensor<T>(0.7978845608)
    // Approximate the Gauss error function via tanh.
    // NOTE: These stay as separate statements because combining them into a single
    // expression makes the compiler report an "unable to type-check this in
    // reasonable time" error.
    let erfApproximation = tanh(sqrtTwoOverPi * (x + 0.044715 * pow(x, 3)))
    let gaussianCDF = 0.5 * (1.0 + erfApproximation)
    return x * gaussianCDF
}
1071
+
1054
1072
//===------------------------------------------------------------------------------------------===//
1055
1073
// Element-wise Binary Math Functions
1056
1074
//===------------------------------------------------------------------------------------------===//
Original file line number Diff line number Diff line change @@ -243,6 +243,13 @@ final class MathOperatorTests: XCTestCase {
243
243
XCTAssertEqual ( y, expected)
244
244
}
245
245
246
+ func testGelu( ) {
247
+ let x = Tensor < Float > ( [ 2.0 , 1.0 , 7.0 ] )
248
+ let y = gelu ( x)
249
+ let expected = Tensor < Float > ( [ 1.95459769 , 0.84119199 , 7.0 ] )
250
+ XCTAssertEqual ( y, expected)
251
+ }
252
+
246
253
func testLeakyRelu( ) {
247
254
let x = Tensor < Float > ( [ [ - 1.0 , 2.0 , 3.0 ] ] )
248
255
let y = leakyRelu ( x, alpha: 0.4 )
@@ -318,6 +325,7 @@ final class MathOperatorTests: XCTestCase {
318
325
( " testReduction " , testReduction) ,
319
326
( " testCosineSimilarity " , testCosineSimilarity) ,
320
327
( " testElu " , testElu) ,
328
+ ( " testGelu " , testGelu) ,
321
329
( " testArgmax " , testArgmax) ,
322
330
( " testSoftplus " , testSoftplus) ,
323
331
( " testSoftsign " , testSoftsign) ,
You can’t perform that action at this time.
0 commit comments