This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Added support for 'Tensor.batchGathering(atIndices:)'. #157

Merged
merged 44 commits on Jun 21, 2019
Changes from all commits (44 commits)
930cf5f
Enhanced the 'matmul' wrapper so that it matches the behavior of the …
eaplatanios May 30, 2019
a557090
Added support for the 'log1mexp' op and its VJP.
eaplatanios May 30, 2019
9e75132
Added a test.
eaplatanios May 30, 2019
571a301
Update Sources/TensorFlow/Operators/Math.swift
eaplatanios May 30, 2019
2131375
Removed the need for a general 'Tensor.withoutDerivative()' as Richar…
eaplatanios May 30, 2019
1e80a1e
Addressed Richard's feedback.
eaplatanios May 30, 2019
3b60a9e
Addressed Richard's feedback.
eaplatanios May 31, 2019
9ef8db8
Added one more test helper.
eaplatanios May 31, 2019
561a842
Minor bug fix.
eaplatanios May 31, 2019
670eabf
Merge branch 'matmul' into logm1exp
eaplatanios May 31, 2019
a01f161
Added a test for 'log1mexp'.
eaplatanios May 31, 2019
399aba6
Merge branch 'matmul' into log-sigmoid
eaplatanios May 31, 2019
a30c098
Added support for 'softplus' and 'logSigmoid'.
eaplatanios May 31, 2019
7b7585e
Minor tweak.
eaplatanios May 31, 2019
6c5b2a6
Merge branch 'matmul' into is-finite
eaplatanios May 31, 2019
102fba1
Added support for 'isFinite', 'isInfinite', and 'isNaN'.
eaplatanios May 31, 2019
b874908
Addressed Richard's feedback.
eaplatanios May 31, 2019
001d2de
Addressed Richard's feedback.
eaplatanios May 31, 2019
7c950e8
Merge branch 'matmul' into gathering
eaplatanios May 31, 2019
e129bcb
Added support for 'gathering' and its VJP.
eaplatanios May 31, 2019
f1f9bd8
Added a test for 'gathering'.
eaplatanios May 31, 2019
4bc96c6
Update Sources/TensorFlow/Operators/Basic.swift
eaplatanios May 31, 2019
412451a
Merge remote-tracking branch 'upstream/master' into gathering
eaplatanios May 31, 2019
2876f9b
Removed a redundant helper.
eaplatanios May 31, 2019
6020874
Update Tests/TensorFlowTests/OperatorTests/BasicTests.swift
eaplatanios May 31, 2019
6fd7b36
Added support for 'Tensor.batchGathering(atIndices:)'.
eaplatanios May 31, 2019
d33db18
Merged upstream changes.
eaplatanios May 31, 2019
a0384e7
Fixed some of the tests.
eaplatanios May 31, 2019
0fbac79
Made the tests pass.
eaplatanios May 31, 2019
9701780
Attempt at making 'log1mexp' differentiable.
eaplatanios Jun 1, 2019
b6a6f65
Removed 'Tensor.nonZeroIndices()'.
eaplatanios Jun 7, 2019
51a834f
Renamed 'withoutDerivative' to 'noDerivative'.
eaplatanios Jun 7, 2019
8195cb2
Merged upstream changes.
eaplatanios Jun 7, 2019
5adbde2
Added back 'Tensor.nonZeroIndices()'.
eaplatanios Jun 7, 2019
795e2cf
Merged upstream changes.
eaplatanios Jun 20, 2019
8869b75
Merged upstream changes.
eaplatanios Jun 20, 2019
2b6a5ba
Enabled the 'logSigmoid' test.
eaplatanios Jun 20, 2019
12c96e8
Merged upstream changes.
eaplatanios Jun 20, 2019
937e285
Merged upstream changes.
eaplatanios Jun 20, 2019
550e5ec
Minor edit.
eaplatanios Jun 20, 2019
556b273
Merged upstream changes.
eaplatanios Jun 20, 2019
8d37218
Merged upstream changes.
eaplatanios Jun 21, 2019
98c6794
Style edit.
eaplatanios Jun 21, 2019
e963797
Style edit.
eaplatanios Jun 21, 2019
43 changes: 43 additions & 0 deletions Sources/TensorFlow/Operators/Basic.swift
@@ -387,6 +387,49 @@ public extension Tensor {
return Raw.gatherV2(params: self, indices: indices, axis: Tensor<Int32>(Int32(axis)))
}

/// Returns slices of this tensor at `indices`, while ignoring the first dimension, which is
/// treated as a batch dimension. The gather is performed along the first non-batch dimension.
///
/// Performs similar functionality to `gathering`, except that the resulting tensor shape is
/// now `indices.shape + self.shape[2...]`.
///
/// - Parameter indices: Contains the indices to gather.
///
/// - Precondition: The leading (batch) dimension of `indices` must match the leading
///   dimension of this tensor.
///
/// - Returns: The gathered tensor.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func batchGathering(atIndices indices: Tensor<Int32>) -> Tensor {
var batchIndices = indices
var accumulated = Tensor<Int32>(ones: [])
// The size of the dimension along which the gather is performed (the first non-batch
// dimension).
accumulated *= Swift.withoutDerivative(at: shapeTensor) { $0[1] }
// The size of the batch dimension.
let dValue = Swift.withoutDerivative(at: shapeTensor) { $0[0] }
// Per-batch offsets into the flattened tensor: batch `i` starts at `i * accumulated`.
let dIndices = Tensor<Int32>(
rangeFrom: Tensor<Int32>(zeros: []),
to: dValue,
stride: Tensor<Int32>(ones: [])
) * accumulated
// Reshape the offsets to `[batchSize, 1, ..., 1]` so that they broadcast against `indices`.
let dShape = Tensor<Int32>(concatenating: [
dValue.rankLifted(),
Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
batchIndices += dIndices.reshaped(toShape: dShape)
let flatIndices = batchIndices.flattened()
// Collapse the batch and gather dimensions into one, gather, and then restore the shape.
let outerShape = Swift.withoutDerivative(at: shapeTensor) { $0[2...] }
let innerShape = Swift.withoutDerivative(at: shapeTensor) { $0[..<2] }.product(squeezingAxes: [0])
let flatTensor = reshaped(toShape: innerShape.rankLifted().concatenated(with: outerShape))
let flatResult = flatTensor.gathering(atIndices: flatIndices)
return flatResult.reshaped(toShape: indices.shapeTensor.concatenated(with: outerShape))
}
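
// A minimal usage sketch (illustrative values, not part of this diff): per-batch
// gathering along the first non-batch dimension.
//
//   let x = Tensor<Float>([[10, 20, 30], [40, 50, 60]])  // shape [2, 3]
//   let i = Tensor<Int32>([[1, 0], [2, 2]])              // shape [2, 2]
//   x.batchGathering(atIndices: i)                       // [[20.0, 10.0], [60.0, 60.0]]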

/// Returns a tensor by gathering the values after applying the provided boolean mask to the input.
///
/// For example:
85 changes: 52 additions & 33 deletions Sources/TensorFlow/Operators/Math.swift
@@ -784,23 +784,6 @@ public func rsqrt<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
Raw.rsqrt(x)
}

/// Returns the cosine similarity between `x` and `y`.
@differentiable
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
) -> Tensor<Scalar> {
(x * y).sum() / (sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
}

/// Returns the cosine distance between `x` and `y`. Cosine distance is defined as
/// `1 - cosineSimilarity(x, y)`.
@differentiable
public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
) -> Tensor<Scalar> {
1 - cosineSimilarity(x, y)
}

@inlinable
internal func _vjpRsqrt<T: TensorFlowFloatingPoint>(
_ x: Tensor<T>
@@ -925,6 +908,14 @@ internal func _vjpSigmoid<T: TensorFlowFloatingPoint>(
(sigmoid(x), { v in Raw.sigmoidGrad(x, dy: v) })
}

/// Returns the log-sigmoid of the specified tensor element-wise. Specifically,
/// `y = log(1 / (1 + exp(-x)))`. For numerical stability, we use `y = -softplus(-x)`.
@inlinable
@differentiable
public func logSigmoid<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
-softplus(-x)
}
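
// A small numerical-stability sketch (illustrative, not part of this diff): for
// large negative inputs the naive form underflows, while `-softplus(-x)` does not.
//
//   let x = Tensor<Float>([-800])
//   log(sigmoid(x))  // [-inf], because sigmoid(-800) underflows to 0.
//   logSigmoid(x)    // [-800.0], computed stably as -softplus(800).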

/// Returns the softplus of the specified tensor element-wise.
/// Specifically, computes `log(exp(features) + 1)`.
@inlinable
@@ -1016,6 +1007,24 @@ func _vjpElu<T: TensorFlowFloatingPoint>(
return (y, { v in Raw.eluGrad(gradients: v, outputs: y) })
}

/// Returns the Gaussian Error Linear Unit (GELU) activations of the specified tensor element-wise.
///
/// Specifically, `gelu` approximates `xP(X <= x)`, where `P(X <= x)` is the Standard Gaussian
/// cumulative distribution, by computing: x * [0.5 * (1 + tanh[√(2/π) * (x + 0.044715 * x^3)])].
///
/// See [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
@inlinable
@differentiable
public func gelu<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
let ratio = Tensor<T>(0.7978845608) // An approximation of √(2/π).
// An approximation of the Gauss error function.
// NOTE: This is needed because the compiler otherwise gives an "unable to type-check this
// in reasonable time" error when the below expressions are written on a single line.
let approximateErf = tanh(ratio * (x + 0.044715 * pow(x, 3)))
let cdf = 0.5 * (1.0 + approximateErf)
return x * cdf
}
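
// A worked check of the approximation at x = 1 (a sketch; the value matches the
// 'testGelu' expectation below): √(2/π) * (1 + 0.044715) ≈ 0.83356,
// tanh(0.83356) ≈ 0.68238, so gelu(1) ≈ 0.5 * (1 + 0.68238) ≈ 0.84119.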

/// Returns a tensor by applying the leaky ReLU activation function
/// to the specified tensor element-wise.
/// Specifically, computes `max(x, x * alpha)`.
@@ -1053,22 +1062,15 @@ func _vjpRelu<T: TensorFlowFloatingPoint>(
(relu(x), { v in Tensor(x .> 0) * v })
}

/// Returns the Gaussian Error Linear Unit (GELU) activations of the specified tensor element-wise.
///
/// Specifically, `gelu` approximates `xP(X <= x)`, where `P(X <= x)` is the Standard Gaussian
/// cumulative distribution, by computing: x * [0.5 * (1 + tanh[√(2/π) * (x + 0.044715 * x^3)])].
///
/// See [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
@inlinable
@differentiable
public func gelu<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
let ratio = Tensor<T>(0.7978845608) // An approximation of √(2/π).
// An approximation of the Gauss error function.
// NOTE: This is needed because the compiler otherwise gives an "unable to type-check this
// in reasonable time" error when the below expressions are written on a single line.
let approximateErf = tanh(ratio * (x + 0.044715 * pow(x, 3)))
let cdf = 0.5 * (1.0 + approximateErf)
return x * cdf
public extension Tensor where Scalar: TensorFlowFloatingPoint {
/// Returns a boolean tensor indicating which elements of this tensor are finite.
@inlinable var isFinite: Tensor<Bool> { Raw.isFinite(self) }

/// Returns a boolean tensor indicating which elements of this tensor are infinite.
@inlinable var isInfinite: Tensor<Bool> { Raw.isInf(self) }

/// Returns a boolean tensor indicating which elements of this tensor are NaN-valued.
@inlinable var isNaN: Tensor<Bool> { Raw.isNan(self) }
}
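
// Example (a sketch, not part of this diff):
//
//   let x = Tensor<Float>([1, .infinity, .nan])
//   x.isFinite    // [true, false, false]
//   x.isInfinite  // [false, true, false]
//   x.isNaN       // [false, false, true]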

//===------------------------------------------------------------------------------------------===//
@@ -1202,6 +1204,23 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
}

/// Returns the cosine similarity between `x` and `y`.
@differentiable
public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
) -> Tensor<Scalar> {
(x * y).sum() / (sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
}

/// Returns the cosine distance between `x` and `y`. Cosine distance is defined as
/// `1 - cosineSimilarity(x, y)`.
@differentiable
public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
_ x: Tensor<Scalar>, _ y: Tensor<Scalar>
) -> Tensor<Scalar> {
1 - cosineSimilarity(x, y)
}
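
// A worked example (illustrative): for x = [3, 4] and y = [4, 3], the dot product
// is 24 and both norms are 5, so cosineSimilarity(x, y) == 24 / 25 == 0.96 and
// cosineDistance(x, y) == 0.04.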

//===------------------------------------------------------------------------------------------===//
// Selection Functions
//===------------------------------------------------------------------------------------------===//
135 changes: 83 additions & 52 deletions Tests/TensorFlowTests/OperatorTests/MathTests.swift
@@ -60,34 +60,92 @@ final class MathOperatorTests: XCTestCase {
}

func testLog1p() {
let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
let x = Tensor<Float>([1, 2, 3, 4, 5])
let y = log1p(x)
assertEqual(y, log(1 + x), accuracy: 0.0001)
let expectedY = Tensor<Float>([0.69315, 1.09861, 1.38629, 1.60944, 1.79176])
assertEqual(y, expectedY, accuracy: 0.0001)
}

func testCosineSimilarity() {
let x = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
let y = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
let z = cosineSimilarity(x, y)
let output: Float = 1.0
XCTAssertEqual(z, Tensor(output))
}

// FIXME(https://bugs.swift.org/browse/TF-543): Disable failing test.
/*
func testExpm1() {
let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
let x = Tensor<Float>([1, 2, 3, 4, 5])
let y = expm1(x)
assertEqual(y, exp(x - 1), accuracy: 0.0001)
let expectedY = Tensor<Float>([1.71828, 6.38906, 19.08554, 53.59815, 147.41316])
assertEqual(y, expectedY, accuracy: 0.0001)
}
*/

func testSign() {
let x = Tensor<Float>([[1, 2, -3, 4, 5], [1, 2, 3, 4, -5]])
let y = sign(x)
XCTAssertEqual(y, Tensor<Float>([[1, 1, -1, 1, 1], [1, 1, 1, 1, -1]]))
}

func testLogSigmoid() {
let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
let y = logSigmoid(x)
assertEqual(y, log(sigmoid(x)), accuracy: 0.0001)
}

func testSoftplus() {
let x = Tensor<Float>([1.0, 2.0, 3.0])
let y = softplus(x)
let expected = Tensor<Float>([1.3132616, 2.126928, 3.0485873])
XCTAssertEqual(y, expected)
}

func testSoftsign() {
let x = Tensor<Float>([1.0, 4.0, 3.0])
let y = softsign(x)
let expected = Tensor<Float>([0.5, 0.8, 0.75])
XCTAssertEqual(y, expected)
}

func testElu() {
let x = Tensor<Float>([-1.0, 2.0, 3.0])
let y = elu(x)
let expected = Tensor<Float>([-0.63212055, 2, 3])
XCTAssertEqual(y, expected)
}

func testGelu() {
let x = Tensor<Float>([2.0, 1.0, 7.0])
let y = gelu(x)
let expected = Tensor<Float>([1.95459769, 0.84119199, 7.0])
XCTAssertEqual(y, expected)
}

func testLeakyRelu() {
let x = Tensor<Float>([[-1.0, 2.0, 3.0]])
let y = leakyRelu(x, alpha: 0.4)
let expected = Tensor<Float>([-0.4, 2, 3])
XCTAssertEqual(y, expected)
}

func testIsFinite() {
let x = Tensor<Float>([1, 2, 3, 4, -Float.infinity])
let y = x.isFinite
XCTAssertEqual(y, Tensor([true, true, true, true, false]))
}

func testIsInfinite() {
let x = Tensor<Float>([1, 2, 3, 4, log(0.0)])
let y = x.isInfinite
XCTAssertEqual(y, Tensor([false, false, false, false, true]))
}

func testIsNaN() {
let x = Tensor<Float>([1, 2, 3, 4, log(-5.0)])
let y = x.isNaN
XCTAssertEqual(y, Tensor([false, false, false, false, true]))
}

func testCosineSimilarity() {
let x = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
let y = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
let z = cosineSimilarity(x, y)
let output: Float = 1.0
XCTAssertEqual(z, Tensor(output))
}

func testReduction() {
// 2 x 5
let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
@@ -222,41 +280,6 @@ final class MathOperatorTests: XCTestCase {
XCTAssertEqual(result.scalars, [12.5, 6.5])
}

func testSoftplus() {
let x = Tensor<Float>([1.0, 2.0, 3.0])
let y = softplus(x)
let expected = Tensor<Float>([1.3132616, 2.126928, 3.0485873])
XCTAssertEqual(y, expected)
}

func testSoftsign() {
let x = Tensor<Float>([1.0, 4.0, 3.0])
let y = softsign(x)
let expected = Tensor<Float>([0.5, 0.8, 0.75])
XCTAssertEqual(y, expected)
}

func testElu() {
let x = Tensor<Float>([-1.0, 2.0, 3.0])
let y = elu(x)
let expected = Tensor<Float>([-0.63212055, 2, 3])
XCTAssertEqual(y, expected)
}

func testGelu() {
let x = Tensor<Float>([2.0, 1.0, 7.0])
let y = gelu(x)
let expected = Tensor<Float>([1.95459769, 0.84119199, 7.0])
XCTAssertEqual(y, expected)
}

func testLeakyRelu() {
let x = Tensor<Float>([[-1.0, 2.0, 3.0]])
let y = leakyRelu(x, alpha: 0.4)
let expected = Tensor<Float>([-0.4, 2, 3])
XCTAssertEqual(y, expected)
}

func testXORInference() {
func xor(_ x: Float, _ y: Float) -> Float {
let x = Tensor<Float>([x, y]).reshaped(to: [1, 2])
@@ -318,18 +341,26 @@
}

static var allTests = [
("testElementaryFunctions", testElementaryFunctions),
("testLog1p", testLog1p),
// FIXME(https://bugs.swift.org/browse/TF-543): Disable failing test.
// ("testExpm1", testExpm1),
("testExpm1", testExpm1),
("testSign", testSign),
("testLogSigmoid", testLogSigmoid),
("testReduction", testReduction),
("testCosineSimilarity", testCosineSimilarity),
("testElu",testElu),
("testGelu", testGelu),
("testArgmax", testArgmax),
("testSoftplus", testSoftplus),
("testSoftsign", testSoftsign),
("testElu",testElu),
("testLeakyRelu", testLeakyRelu),
("testIsFinite", testIsFinite),
("testIsInfinite", testIsInfinite),
("testIsNaN", testIsNaN),
("testCosineSimilarity", testCosineSimilarity),
("testReduction", testReduction),
("testArgmax", testArgmax),
("testCeilAndFloor", testCeilAndFloor),
("testSimpleMath", testSimpleMath),
("testStandardDeviation", testStandardDeviation),