Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5c5f557

Browse files
Shashi456 authored and rxwei committed
Add softplus and softsign (#225)
1 parent d80ec3b commit 5c5f557

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,36 @@ internal func _vjpSigmoid<T: TensorFlowFloatingPoint>(
896896
(sigmoid(x), { v in Raw.sigmoidGrad(x, dy: v) })
897897
}
898898

899+
/// Returns the softplus of the specified tensor element-wise.
/// Specifically, computes `log(exp(features) + 1)`.
///
/// - Parameter features: The input tensor.
/// - Returns: A tensor of the same shape holding the element-wise softplus.
@inlinable
@differentiable(vjp: _vjpSoftplus)
public func softplus<T: TensorFlowFloatingPoint>(_ features: Tensor<T>) -> Tensor<T> {
    // Delegate directly to the raw TensorFlow op.
    return Raw.softplus(features: features)
}
906+
907+
/// VJP (pullback) for `softplus`, backed by the raw `softplusGrad` op.
@inlinable
internal func _vjpSoftplus<T: TensorFlowFloatingPoint>(
    _ features: Tensor<T>
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
    let value = softplus(features)
    // The pullback feeds the incoming cotangent through softplusGrad.
    return (value, { seed in Raw.softplusGrad(gradients: seed, features: features) })
}
913+
914+
/// Returns the softsign of the specified tensor element-wise.
/// Specifically, computes `features / (abs(features) + 1)`.
///
/// - Parameter features: The input tensor.
/// - Returns: A tensor of the same shape holding the element-wise softsign.
@inlinable
@differentiable(vjp: _vjpSoftsign)
public func softsign<T: TensorFlowFloatingPoint>(_ features: Tensor<T>) -> Tensor<T> {
    // Delegate directly to the raw TensorFlow op.
    return Raw.softsign(features: features)
}
921+
922+
/// VJP (pullback) for `softsign`, backed by the raw `softsignGrad` op.
@inlinable
internal func _vjpSoftsign<T: TensorFlowFloatingPoint>(
    _ features: Tensor<T>
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
    let value = softsign(features)
    // The pullback feeds the incoming cotangent through softsignGrad.
    return (value, { seed in Raw.softsignGrad(gradients: seed, features: features) })
}
928+
899929
/// Computes the softmax of the specified tensor along the last axis.
900930
/// Specifically, computes `exp(x) / exp(x).sum(alongAxes: -1)`.
901931
@inlinable

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ final class MathOperatorTests: XCTestCase {
118118
x.variance(squeezingAxes: 0),
119119
Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]))
120120
XCTAssertEqual(
121-
x.variance(alongAxes: 0),
121+
x.variance(alongAxes: 0),
122122
Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]))
123123
XCTAssertEqual(
124124
x.variance(squeezingAxes: 1),
@@ -214,6 +214,20 @@ final class MathOperatorTests: XCTestCase {
214214
XCTAssertEqual(result.scalars, [12.5, 6.5])
215215
}
216216

217+
// Verifies softplus(x) = log(exp(x) + 1) on a small fixed input.
func testSoftplus() {
    let input = Tensor<Float>([1.0, 2.0, 3.0])
    let result = softplus(input)
    let expected: [Float] = [1.3132616, 2.126928, 3.0485873]
    XCTAssertEqual(result.shape, input.shape)
    // Compare scalar-wise with a tolerance: exact Float equality on a
    // transcendental op is brittle across platforms and op implementations.
    for (actual, anticipated) in zip(result.scalars, expected) {
        XCTAssertEqual(actual, anticipated, accuracy: 1e-6)
    }
}
223+
224+
// Verifies softsign(x) = x / (abs(x) + 1) on a small fixed input.
func testSoftsign() {
    let input = Tensor<Float>([1.0, 4.0, 3.0])
    let result = softsign(input)
    let expected: [Float] = [0.5, 0.8, 0.75]
    XCTAssertEqual(result.shape, input.shape)
    // Compare scalar-wise with a tolerance: exact Float equality is brittle
    // across platforms and op implementations.
    for (actual, anticipated) in zip(result.scalars, expected) {
        XCTAssertEqual(actual, anticipated, accuracy: 1e-6)
    }
}
230+
217231
func testXORInference() {
218232
func xor(_ x: Float, _ y: Float) -> Float {
219233
let x = Tensor<Float>([x, y]).reshaped(to: [1, 2])
@@ -281,6 +295,8 @@ final class MathOperatorTests: XCTestCase {
281295
("testSign", testSign),
282296
("testReduction", testReduction),
283297
("testArgmax", testArgmax),
298+
("testSoftplus", testSoftplus),
299+
("testSoftsign", testSoftsign),
284300
("testCeilAndFloor", testCeilAndFloor),
285301
("testSimpleMath", testSimpleMath),
286302
("testStandardDeviation", testStandardDeviation),

0 commit comments

Comments
 (0)