Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 349ae47

Browse files
eaplataniosrxwei
authored andcommitted
Added support for the 'sign' op and its VJP. (#144)
1 parent f194110 commit 349ae47

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,19 @@ internal func _vjpFloor<T: TensorFlowFloatingPoint>(
597597
return (floor(x), { _ in Tensor(0).broadcasted(like: x) })
598598
}
599599

600+
@inlinable
601+
@differentiable(vjp: _vjpSign(_:) where T: TensorFlowFloatingPoint)
602+
public func sign<T: Numeric>(_ x: Tensor<T>) -> Tensor<T> {
603+
return Raw.sign(x)
604+
}
605+
606+
@inlinable
607+
internal func _vjpSign<T: TensorFlowFloatingPoint>(
608+
_ x: Tensor<T>
609+
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
610+
return (sign(x), { v in Tensor<T>(zerosLike: x) })
611+
}
612+
600613
/// Computes the sigmoid of the specified tensor element-wise.
601614
/// Specifically, computes `1 / (1 + exp(-x))`.
602615
@inlinable

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ import XCTest
1616
@testable import TensorFlow
1717

1818
final class MathOperatorTests: XCTestCase {
19+
func testSign() {
20+
let x = Tensor<Float>([[1, 2, -3, 4, 5], [1, 2, 3, 4, -5]])
21+
let y = sign(x)
22+
XCTAssertEqual(y, Tensor<Float>([[1, 1, -1, 1, 1], [1, 1, 1, 1, -1]]))
23+
}
24+
1925
func testReduction() {
2026
// 2 x 5
2127
let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
@@ -211,6 +217,7 @@ final class MathOperatorTests: XCTestCase {
211217
}
212218

213219
static var allTests = [
220+
("testSign", testSign),
214221
("testReduction", testReduction),
215222
("testArgmax", testArgmax),
216223
("testCeilAndFloor", testCeilAndFloor),

0 commit comments

Comments
 (0)