
Commit d264701

Added support for 'Tensor.batchGathering(atIndices:)'. (#157)
* Enhanced the 'matmul' wrapper so that it matches the behavior of the Python one.
* Added support for the 'log1mexp' op and its VJP.
* Added a test.
* Update Sources/TensorFlow/Operators/Math.swift (Co-Authored-By: Richard Wei <[email protected]>)
* Removed the need for a general 'Tensor.withoutDerivative()' as Richard suggested.
* Addressed Richard's feedback.
* Addressed Richard's feedback.
* Added one more tests helper.
* Minor bug fix.
* Added a test for 'log1mexp'.
* Added support for 'softplus' and 'logSigmoid'.
* Minor tweak.
* Added support for 'isFinite', 'isInfinite', and 'isNaN'.
* Addressed Richard's feedback.
* Addressed Richard's feedback.
* Added support for 'gathering' and its VJP.
* Added a test for 'gathering'.
* Update Sources/TensorFlow/Operators/Basic.swift (Co-Authored-By: Richard Wei <[email protected]>)
* Removed a redundant helper.
* Update Tests/TensorFlowTests/OperatorTests/BasicTests.swift (Co-Authored-By: Richard Wei <[email protected]>)
* Added support for a 'Tensor.batchGathering(atIndices:)'.
* Fixed some of the tests.
* Made the tests pass.
* Attempt at making 'log1mexp' differentiable.
* Removed 'Tensor.nonZeroIndices()'.
* Renamed 'withoutDerivative' to 'noDerivative'.
* Added back 'Tensor.nonZeroIndices()'.
* Merged upstream changes.
* Enabled the 'logSigmoid' test.
* Minor edit.
* Style edit.
* Style edit.
1 parent: b32d289

3 files changed: +178 additions, −85 deletions

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 43 additions & 0 deletions
````diff
@@ -387,6 +387,49 @@ public extension Tensor {
         return Raw.gatherV2(params: self, indices: indices, axis: Tensor<Int32>(Int32(axis)))
     }
 
+    /// Returns slices of this tensor at `indices`, while ignoring the first `batchDims` dimensions
+    /// that correspond to batch dimensions. The gather is performed along the first non-batch
+    /// dimension.
+    ///
+    /// Performs similar functionality to `gathering`, except that the resulting tensor shape is
+    /// now:
+    /// ```
+    /// self.shape[..<batchDims] +
+    /// indices.shape[batchDims...] +
+    /// self.shape[(batchDims + indices.rank + 1)...]
+    /// ```
+    ///
+    /// - Parameters:
+    ///   - indices: Contains the indices to gather.
+    ///   - batchDims: Number of leading batch dimensions to ignore.
+    ///
+    /// - Precondition: `batchDims` must be less than `indices.rank`.
+    ///
+    /// - Returns: The gathered tensor.
+    @inlinable
+    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+    func batchGathering(atIndices indices: Tensor<Int32>) -> Tensor {
+        var batchIndices = indices
+        var accumulated = Tensor<Int32>(ones: [])
+        accumulated *= Swift.withoutDerivative(at: shapeTensor) { $0[1] }
+        let dValue = Swift.withoutDerivative(at: shapeTensor) { $0[0] }
+        let dIndices = Tensor<Int32>(
+            rangeFrom: Tensor<Int32>(zeros: []),
+            to: dValue,
+            stride: Tensor<Int32>(ones: [])
+        ) * accumulated
+        let dShape = Tensor<Int32>(concatenating: [
+            dValue.rankLifted(),
+            Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
+        batchIndices += dIndices.reshaped(toShape: dShape)
+        let flatIndices = batchIndices.flattened()
+        let outerShape = Swift.withoutDerivative(at: shapeTensor) { $0[2...] }
+        let innerShape = Swift.withoutDerivative(at: shapeTensor) { $0[..<2] }.product(squeezingAxes: [0])
+        let flatTensor = reshaped(toShape: innerShape.rankLifted().concatenated(with: outerShape))
+        let flatResult = flatTensor.gathering(atIndices: flatIndices)
+        return flatResult.reshaped(toShape: indices.shapeTensor.concatenated(with: outerShape))
+    }
+
     /// Returns a tensor by gathering the values after applying the provided boolean mask to the input.
     ///
     /// For example:
````
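For readers scanning the diff, here is a minimal usage sketch of the new method. The values and shapes are hypothetical, written against the swift-apis API at this commit:

```swift
import TensorFlow

// `x` has shape [2, 3, 4]: a batch dimension of size 2, then 3 rows of 4 features each.
let x = Tensor<Float>(rangeFrom: 0, to: 24, stride: 1).reshaped(to: [2, 3, 4])
// For each of the 2 batches, pick two rows; `indices` has shape [2, 2].
let indices = Tensor<Int32>([[1, 0], [0, 2]])
let gathered = x.batchGathering(atIndices: indices)
// The result is reshaped to indices.shape + x.shape[2...], i.e. [2, 2, 4]:
// gathered[0] == [x[0][1], x[0][0]] and gathered[1] == [x[1][0], x[1][2]].
print(gathered.shape) // [2, 2, 4]
```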

Sources/TensorFlow/Operators/Math.swift

Lines changed: 52 additions & 33 deletions
```diff
@@ -784,23 +784,6 @@ public func rsqrt<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
     Raw.rsqrt(x)
 }
 
-/// Returns the cosine similarity between `x` and `y`.
-@differentiable
-public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
-    _ x: Tensor<Scalar>, _ y: Tensor<Scalar>
-) -> Tensor<Scalar> {
-    (x * y).sum() / (sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
-}
-
-/// Returns the cosine distance between `x` and `y`. Cosine distance is defined as
-/// `1 - cosineSimilarity(x, y)`.
-@differentiable
-public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
-    _ x: Tensor<Scalar>, _ y: Tensor<Scalar>
-) -> Tensor<Scalar> {
-    1 - cosineSimilarity(x, y)
-}
-
 @inlinable
 internal func _vjpRsqrt<T: TensorFlowFloatingPoint>(
     _ x: Tensor<T>
```
```diff
@@ -925,6 +908,14 @@ internal func _vjpSigmoid<T: TensorFlowFloatingPoint>(
     (sigmoid(x), { v in Raw.sigmoidGrad(x, dy: v) })
 }
 
+/// Returns the log-sigmoid of the specified tensor element-wise. Specifically,
+/// `y = log(1 / (1 + exp(-x)))`. For numerical stability, we use `y = -softplus(-x)`.
+@inlinable
+@differentiable
+public func logSigmoid<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+    -softplus(-x)
+}
+
 /// Returns the softplus of the specified tensor element-wise.
 /// Specifically, computes `log(exp(features) + 1)`.
 @inlinable
```
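The doc comment's stability claim is worth spelling out: for large-magnitude negative inputs, `exp(-x)` overflows `Float`, so the naive formula collapses to `-inf`, while `softplus` stays accurate. An illustrative sketch (assuming this toolchain; the values are for demonstration):

```swift
import TensorFlow

let x = Tensor<Float>([-100.0])
// Naive: exp(100) overflows Float to +inf, so log(1 / (1 + inf)) == log(0) == -inf.
let naive = log(1 / (1 + exp(-x)))
// Stable: logSigmoid(-100) == -softplus(100) ≈ -100, the mathematically correct value.
let stable = logSigmoid(x)
print(naive, stable) // roughly [-inf] versus [-100.0]
```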
```diff
@@ -1016,6 +1007,24 @@ func _vjpElu<T: TensorFlowFloatingPoint>(
     return (y, { v in Raw.eluGrad(gradients: v, outputs: y) })
 }
 
+/// Returns the Gaussian Error Linear Unit (GELU) activations of the specified tensor element-wise.
+///
+/// Specifically, `gelu` approximates `xP(X <= x)`, where `P(X <= x)` is the Standard Gaussian
+/// cumulative distribution, by computing: x * [0.5 * (1 + tanh[√(2/π) * (x + 0.044715 * x^3)])].
+///
+/// See [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
+@inlinable
+@differentiable
+public func gelu<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+    let ratio = Tensor<T>(0.7978845608) // An approximation of √(2/π).
+    // An approximation of the Gauss error function.
+    // NOTE: This is needed because the compiler otherwise gives an "unable to type-check this
+    // in reasonable time" error when the below expressions are written on a single line.
+    let approximateErf = tanh(ratio * (x + 0.044715 * pow(x, 3)))
+    let cdf = 0.5 * (1.0 + approximateErf)
+    return x * cdf
+}
+
 /// Returns a tensor by applying the leaky ReLU activation function
 /// to the specified tensor element-wise.
 /// Specifically, computes `max(x, x * alpha)`.
```
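A hand-worked spot check of the tanh approximation at `x = 1` (the arithmetic is mine, but the result matches the value asserted in `testGelu` below): √(2/π) · (1 + 0.044715) ≈ 0.7978845608 · 1.044715 ≈ 0.83356; tanh(0.83356) ≈ 0.68237; so gelu(1) ≈ 1 · 0.5 · (1 + 0.68237) ≈ 0.84119.

```swift
import TensorFlow

// Spot check: should print a value close to 0.84119199, as asserted in MathTests.
print(gelu(Tensor<Float>([1.0])))
```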
```diff
@@ -1053,22 +1062,15 @@ func _vjpRelu<T: TensorFlowFloatingPoint>(
     (relu(x), { v in Tensor(x .> 0) * v })
 }
 
-/// Returns the Gaussian Error Linear Unit (GELU) activations of the specified tensor element-wise.
-///
-/// Specifically, `gelu` approximates `xP(X <= x)`, where `P(X <= x)` is the Standard Gaussian
-/// cumulative distribution, by computing: x * [0.5 * (1 + tanh[√(2/π) * (x + 0.044715 * x^3)])].
-///
-/// See [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
-@inlinable
-@differentiable
-public func gelu<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
-    let ratio = Tensor<T>(0.7978845608) // An approximation of √(2/π).
-    // An approximation of the Gauss error function.
-    // NOTE: This is needed because the compiler otherwise gives an "unable to type-check this
-    // in reasonable time" error when the below expressions are written on a single line.
-    let approximateErf = tanh(ratio * (x + 0.044715 * pow(x, 3)))
-    let cdf = 0.5 * (1.0 + approximateErf)
-    return x * cdf
+public extension Tensor where Scalar: TensorFlowFloatingPoint {
+    /// Returns a boolean tensor indicating which elements of `x` are finite.
+    @inlinable var isFinite: Tensor<Bool> { Raw.isFinite(self) }
+
+    /// Returns a boolean tensor indicating which elements of `x` are infinite.
+    @inlinable var isInfinite: Tensor<Bool> { Raw.isInf(self) }
+
+    /// Returns a boolean tensor indicating which elements of `x` are NaN-valued.
+    @inlinable var isNaN: Tensor<Bool> { Raw.isNan(self) }
 }
 
 //===------------------------------------------------------------------------------------------===//
```
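A quick illustration of how the three new predicates partition the special values (hypothetical input, mirroring the behavior the tests below assert):

```swift
import TensorFlow

let t = Tensor<Float>([1.0, Float.infinity, -Float.infinity, Float.nan])
print(t.isFinite)   // [true, false, false, false]
print(t.isInfinite) // [false, true, true, false]
print(t.isNaN)      // [false, false, false, true]
```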
```diff
@@ -1202,6 +1204,23 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
         rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
 }
 
+/// Returns the cosine similarity between `x` and `y`.
+@differentiable
+public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
+    _ x: Tensor<Scalar>, _ y: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    (x * y).sum() / (sqrt(x.squared().sum()) * sqrt(y.squared().sum()))
+}
+
+/// Returns the cosine distance between `x` and `y`. Cosine distance is defined as
+/// `1 - cosineSimilarity(x, y)`.
+@differentiable
+public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
+    _ x: Tensor<Scalar>, _ y: Tensor<Scalar>
+) -> Tensor<Scalar> {
+    1 - cosineSimilarity(x, y)
+}
+
 //===------------------------------------------------------------------------------------------===//
 // Selection Functions
 //===------------------------------------------------------------------------------------------===//
```
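Because cosine similarity is scale-invariant, any pair of parallel vectors yields similarity 1 and distance 0; `testCosineSimilarity` below relies on exactly this. A tiny worked example (illustrative values):

```swift
import TensorFlow

let x = Tensor<Float>([1, 2, 3])
let y = Tensor<Float>([2, 4, 6]) // y = 2x, so the two vectors are parallel.
// (x · y) / (‖x‖ · ‖y‖) = 28 / (√14 · √56) = 28 / 28 = 1.
print(cosineSimilarity(x, y)) // ≈ 1.0
print(cosineDistance(x, y))   // ≈ 0.0
```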

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 83 additions & 52 deletions
```diff
@@ -60,34 +60,92 @@ final class MathOperatorTests: XCTestCase {
     }
 
     func testLog1p() {
-        let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
+        let x = Tensor<Float>([1, 2, 3, 4, 5])
         let y = log1p(x)
-        assertEqual(y, log(1 + x), accuracy: 0.0001)
+        let expectedY = Tensor<Float>([0.69315, 1.09861, 1.38629, 1.60944, 1.79176])
+        assertEqual(y, expectedY, accuracy: 0.0001)
     }
 
-    func testCosineSimilarity() {
-        let x = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
-        let y = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
-        let z = cosineSimilarity(x, y)
-        let output: Float = 1.0
-        XCTAssertEqual(z, Tensor(output))
-    }
-
-    // FIXME(https://bugs.swift.org/browse/TF-543): Disable failing test.
-    /*
     func testExpm1() {
-        let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
+        let x = Tensor<Float>([1, 2, 3, 4, 5])
         let y = expm1(x)
-        assertEqual(y, exp(x - 1), accuracy: 0.0001)
+        let expectedY = Tensor<Float>([1.71828, 6.38906, 19.08554, 53.59815, 147.41316])
+        assertEqual(y, expectedY, accuracy: 0.0001)
     }
-    */
 
     func testSign() {
         let x = Tensor<Float>([[1, 2, -3, 4, 5], [1, 2, 3, 4, -5]])
         let y = sign(x)
         XCTAssertEqual(y, Tensor<Float>([[1, 1, -1, 1, 1], [1, 1, 1, 1, -1]]))
     }
 
+    func testLogSigmoid() {
+        let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
+        let y = logSigmoid(x)
+        assertEqual(y, log(sigmoid(x)), accuracy: 0.0001)
+    }
+
+    func testSoftplus() {
+        let x = Tensor<Float>([1.0, 2.0, 3.0])
+        let y = softplus(x)
+        let expected = Tensor<Float>([1.3132616, 2.126928, 3.0485873])
+        XCTAssertEqual(y, expected)
+    }
+
+    func testSoftsign() {
+        let x = Tensor<Float>([1.0, 4.0, 3.0])
+        let y = softsign(x)
+        let expected = Tensor<Float>([0.5, 0.8, 0.75])
+        XCTAssertEqual(y, expected)
+    }
+
+    func testElu() {
+        let x = Tensor<Float>([-1.0, 2.0, 3.0])
+        let y = elu(x)
+        let expected = Tensor<Float>([-0.63212055, 2, 3])
+        XCTAssertEqual(y, expected)
+    }
+
+    func testGelu() {
+        let x = Tensor<Float>([2.0, 1.0, 7.0])
+        let y = gelu(x)
+        let expected = Tensor<Float>([1.95459769, 0.84119199, 7.0])
+        XCTAssertEqual(y, expected)
+    }
+
+    func testLeakyRelu() {
+        let x = Tensor<Float>([[-1.0, 2.0, 3.0]])
+        let y = leakyRelu(x, alpha: 0.4)
+        let expected = Tensor<Float>([-0.4, 2, 3])
+        XCTAssertEqual(y, expected)
+    }
+
+    func testIsFinite() {
+        let x = Tensor<Float>([1, 2, 3, 4, -Float.infinity])
+        let y = x.isFinite
+        XCTAssertEqual(y, Tensor([true, true, true, true, false]))
+    }
+
+    func testIsInfinite() {
+        let x = Tensor<Float>([1, 2, 3, 4, log(0.0)])
+        let y = x.isInfinite
+        XCTAssertEqual(y, Tensor([false, false, false, false, true]))
+    }
+
+    func testIsNaN() {
+        let x = Tensor<Float>([1, 2, 3, 4, log(-5.0)])
+        let y = x.isNaN
+        XCTAssertEqual(y, Tensor([false, false, false, false, true]))
+    }
+
+    func testCosineSimilarity() {
+        let x = Tensor<Float>([1, 2, 3, 4, 5, 6, 7, 8])
+        let y = Tensor<Float>([0.5, 1, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
+        let z = cosineSimilarity(x, y)
+        let output: Float = 1.0
+        XCTAssertEqual(z, Tensor(output))
+    }
+
     func testReduction() {
         // 2 x 5
         let x = Tensor<Float>([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
```
```diff
@@ -222,41 +280,6 @@ final class MathOperatorTests: XCTestCase {
         XCTAssertEqual(result.scalars, [12.5, 6.5])
     }
 
-    func testSoftplus() {
-        let x = Tensor<Float>([1.0, 2.0, 3.0])
-        let y = softplus(x)
-        let expected = Tensor<Float>([1.3132616, 2.126928, 3.0485873])
-        XCTAssertEqual(y, expected)
-    }
-
-    func testSoftsign() {
-        let x = Tensor<Float>([1.0, 4.0, 3.0])
-        let y = softsign(x)
-        let expected = Tensor<Float>([0.5, 0.8, 0.75])
-        XCTAssertEqual(y, expected)
-    }
-
-    func testElu() {
-        let x = Tensor<Float>([-1.0, 2.0, 3.0])
-        let y = elu(x)
-        let expected = Tensor<Float>([-0.63212055, 2, 3])
-        XCTAssertEqual(y, expected)
-    }
-
-    func testGelu() {
-        let x = Tensor<Float>([2.0, 1.0, 7.0])
-        let y = gelu(x)
-        let expected = Tensor<Float>([1.95459769, 0.84119199, 7.0])
-        XCTAssertEqual(y, expected)
-    }
-
-    func testLeakyRelu() {
-        let x = Tensor<Float>([[-1.0, 2.0, 3.0]])
-        let y = leakyRelu(x, alpha: 0.4)
-        let expected = Tensor<Float>([-0.4, 2, 3])
-        XCTAssertEqual(y, expected)
-    }
-
     func testXORInference() {
         func xor(_ x: Float, _ y: Float) -> Float {
             let x = Tensor<Float>([x, y]).reshaped(to: [1, 2])
```
```diff
@@ -318,18 +341,26 @@ final class MathOperatorTests: XCTestCase {
     }
 
     static var allTests = [
+        ("testElementaryFunctions", testElementaryFunctions),
         ("testLog1p", testLog1p),
-        // FIXME(https://bugs.swift.org/browse/TF-543): Disable failing test.
-        // ("testExpm1", testExpm1),
+        ("testExpm1", testExpm1),
         ("testSign", testSign),
+        ("testLogSigmoid", testLogSigmoid),
         ("testReduction", testReduction),
         ("testCosineSimilarity", testCosineSimilarity),
         ("testElu", testElu),
         ("testGelu", testGelu),
         ("testArgmax", testArgmax),
         ("testSoftplus", testSoftplus),
         ("testSoftsign", testSoftsign),
+        ("testElu", testElu),
         ("testLeakyRelu", testLeakyRelu),
+        ("testIsFinite", testIsFinite),
+        ("testIsInfinite", testIsInfinite),
+        ("testIsNaN", testIsNaN),
+        ("testCosineSimilarity", testCosineSimilarity),
+        ("testReduction", testReduction),
+        ("testArgmax", testArgmax),
         ("testCeilAndFloor", testCeilAndFloor),
         ("testSimpleMath", testSimpleMath),
         ("testStandardDeviation", testStandardDeviation),
```