Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 0e3b3f4

Browse files
eaplatanios authored and rxwei committed
Added support for a 'Tensor.gathering(where:)'. (#156)
1 parent 16374ad commit 0e3b3f4

File tree

3 files changed

+205
-6
lines changed

3 files changed

+205
-6
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,107 @@ public extension Tensor {
335335
/// Concatenation operator: equivalent to `lhs.concatenated(with: rhs)`.
static func ++ (lhs: Tensor, rhs: Tensor) -> Tensor {
    return lhs.concatenated(with: rhs)
}
338+
339+
/// Gathers slices of this tensor at `indices` along the `axis` dimension.
///
/// For 0-D (scalar) `indices`:
/// ```
/// result[p_0, ..., p_{axis-1},
///        p_{axis + 1}, ..., p_{N-1}] =
/// self[p_0, ..., p_{axis-1},
///      indices,
///      p_{axis + 1}, ..., p_{N-1}]
/// ```
///
/// For 1-D (vector) `indices`:
/// ```
/// result[p_0, ..., p_{axis-1},
///        i,
///        p_{axis + 1}, ..., p_{N-1}] =
/// self[p_0, ..., p_{axis-1},
///      indices[i],
///      p_{axis + 1}, ..., p_{N-1}]
/// ```
///
/// In the general case, the resulting tensor satisfies:
/// ```
/// result[p_0, ..., p_{axis-1},
///        i_0, ..., i_{M-1},
///        p_{axis + 1}, ..., p_{N-1}] =
/// self[p_0, ..., p_{axis-1},
///      indices[i_0, ..., i_{M-1}],
///      p_{axis + 1}, ..., p_{N-1}]
/// ```
/// where `N = self.rank` and `M = indices.rank`.
///
/// The shape of the resulting tensor is:
/// `self.shape[..<axis] + indices.shape + self.shape[(axis + 1)...]`.
///
/// - Note: On CPU, if an out-of-range index is found, an error is thrown. On GPU, if an
///   out-of-range index is found, a 0 is stored in the corresponding output values.
///
/// - Parameters:
///   - indices: Contains the indices to gather at.
///   - axis: Dimension along which to gather. Negative values wrap around.
///
/// - Precondition: `axis` must be in the range `[-rank, rank)`.
///
/// - Returns: The gathered tensor.
@inlinable
@differentiable(wrt: self, vjp: _vjpGathering where Scalar : TensorFlowFloatingPoint)
func gathering(atIndices indices: Tensor<Int32>, alongAxis axis: Int = 0) -> Tensor {
    // `GatherV2` takes the axis as a scalar tensor; negative axes are resolved by the op.
    let axisTensor = Tensor<Int32>(Int32(axis))
    return Raw.gatherV2(params: self, indices: indices, axis: axisTensor)
}
389+
390+
/// Gathers values from this tensor according to the provided boolean mask.
///
/// For example:
/// ```
/// // 1-D example
/// // tensor is [0, 1, 2, 3]
/// // mask is [true, false, true, false]
/// tensor.gathering(where: mask) // is [0, 2]
///
/// // 2-D example
/// // tensor is [[1, 2], [3, 4], [5, 6]]
/// // mask is [true, false, true]
/// tensor.gathering(where: mask) // is [[1, 2], [5, 6]]
/// ```
///
/// In general, `0 < mask.rank = K <= tensor.rank`, and the `mask`'s shape must match the first
/// K dimensions of the `tensor`'s shape. We then have:
/// `tensor.gathering(where: mask)[i, j1, ..., jd] = tensor[i1, ..., iK, j1, ..., jd]`, where
/// `[i1, ..., iK]` is the `i`th `true` entry of `mask` (row-major order).
///
/// The `axis` could be used with `mask` to indicate the axis to mask from. In that case,
/// `axis + mask.rank <= tensor.rank` and the `mask`'s shape must match the first
/// `axis + mask.rank` dimensions of the `tensor`'s shape.
///
/// - Parameters:
///   - mask: K-D boolean tensor, where `K <= self.rank`.
///   - axis: 0-D integer tensor representing the axis in `self` to mask from, where
///     `K + axis <= self.rank`.
///
/// - Precondition: The `mask` cannot be a scalar: `mask.rank != 0`.
///
/// - Returns: `(self.rank - K + 1)`-dimensional tensor populated by entries in this tensor
///   corresponding to `true` values in `mask`.
@inlinable
// @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func gathering(where mask: Tensor<Bool>, alongAxis axis: Int = 0) -> Tensor {
    precondition(mask.rank != 0, "The boolean mask cannot be a scalar.")
    // TODO: Remove once control flow AD is supported.
    // The immediately-applied closure below avoids a bare conditional expression so that the
    // body stays differentiable under the current AD limitations — do not simplify it.
    let rank = self.rank
    let posAxis = { axis < 0 ? axis + rank : axis }()
    // Number of elements covered by the mask: the product of the `mask.rank` dimensions of
    // `self` starting at `posAxis`, lifted to a rank-1 tensor for concatenation below.
    let leadingSize = shapeTensor[posAxis ..< posAxis + mask.rank].product().rankLifted()
    // Collapse the masked dimensions into a single dimension so that a plain 1-D gather can
    // select the `true` positions.
    let reshapedTensor = reshaped(
        toShape: Tensor<Int32>(concatenating: [
            shapeTensor[..<posAxis],
            leadingSize,
            shapeTensor[(posAxis + mask.rank)...]]))
    // Row-major linear indices of the `true` entries of `mask`.
    let indices = Tensor<Int32>(mask.flattened().nonZeroIndices().squeezingShape(at: 1))
    return reshapedTensor.gathering(atIndices: indices, alongAxis: posAxis)
}
338439
}
339440

340441
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@@ -375,6 +476,103 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
375476
return (gradients[0], gradients[1])
376477
})
377478
}
479+
480+
/// VJP (reverse-mode derivative) for `gathering(atIndices:alongAxis:)`.
///
/// Returns the forward result together with a pullback that scatter-adds the incoming
/// cotangent `v` back into the shape of `self`: every output slice that was gathered from
/// `self[..., indices[i], ...]` contributes its gradient to that same position.
///
/// - Parameters:
///   - indices: The indices that were gathered at.
///   - axis: Dimension along which the gather was performed. Negative values wrap around.
///
/// - Returns: The gather result and its pullback.
@inlinable
func _vjpGathering(
    atIndices indices: Tensor<Int32>,
    alongAxis axis: Int = 0
) -> (Tensor, (Tensor) -> Tensor) {
    let result = gathering(atIndices: indices, alongAxis: axis)
    let posAxis = axis < 0 ? axis + rank : axis

    // We have a fast gradient implementation for the case when `posAxis == 0`:
    // `unsortedSegmentSum` already sums along the leading dimension, so no transposes are
    // needed.
    if posAxis == 0 {
        return (result, { [shape = shapeTensor] v in
            // Flatten `indices` to 1-D and reshape `v` to match, so each row of `values`
            // corresponds to one gathered index.
            let indicesCount = indices.scalarCountTensor.rankLifted()
            let valuesShape = Tensor<Int32>(concatenating: [indicesCount, shape[1...]])
            let values = v.reshaped(toShape: valuesShape)
            let valueIndices = indices.reshaped(toShape: indicesCount)
            // Sum the gradient rows that map to the same index in `self`.
            return Raw.unsortedSegmentSum(
                data: values,
                segmentIds: valueIndices,
                numSegments: shape[0])
        })
    }

    return (result, { [shape = shapeTensor] v in
        let indicesSize = Tensor<Int32>(Int32(indices.scalarCount)).rankLifted()
        // Dimensions of `self` before and after the gather axis.
        let outerShape = shape[..<posAxis]
        let outerSize = outerShape.scalarCount
        let innerShape = shape[(posAxis + 1)...]
        let innerSize = innerShape.scalarCount
        // Dimension numbers of the outer axes ([0, outerSize)) and of the inner axes
        // ([outerSize + 1, outerSize + 1 + innerSize)) in the reshaped gradient below, where
        // dimension `outerSize` is the flattened gather dimension.
        let outerIndices = Tensor<Int32>(rangeFrom: 0, to: Int32(outerSize), stride: 1)
        let innerIndices = Tensor<Int32>(
            rangeFrom: Int32(outerSize) + 1,
            to: Int32(outerSize) + 1 + Int32(innerSize),
            stride: 1)
        let valuesShape = Tensor<Int32>(concatenating: [outerShape, indicesSize, innerShape])
        let values = v.reshaped(toShape: valuesShape)
        let valueIndices = indices.reshaped(toShape: indicesSize)

        // We need to sum up every slice `values[..., i, ....]` corresponding to
        // `tensor[..., indices[i], ...]`. Since `unsortedSegmentSum` does not support an axis
        // parameter, we transpose the gather dimension to the front, then use
        // `unsortedSegmentSum` to build a `[gatherAxis, outerAxes, innerAxes]` tensor with all
        // the gradients affecting each index in `gatherAxis` summed up.
        let permutations = Tensor<Int32>(concatenating: [
            Tensor<Int32>([Int32(outerSize)]),
            outerIndices,
            innerIndices])
        let transposedValues = values.transposed(withPermutations: permutations)
        let gradient = Raw.unsortedSegmentSum(
            data: transposedValues,
            segmentIds: valueIndices,
            numSegments: shape[posAxis])

        // Finally, we invert the above transpose operation by moving dimension 0 back to its
        // original position.
        let inversePermutations = Tensor<Int32>(concatenating: [
            outerIndices + 1,
            Tensor<Int32>([0]),
            innerIndices])
        return gradient.transposed(withPermutations: inversePermutations)
    })
}
541+
}
542+
543+
public extension Tensor {
    /// Returns the locations of non-zero / true values in this tensor.
    ///
    /// The coordinates are returned as a 2-D tensor: the first dimension (rows) counts the
    /// non-zero elements, and the second dimension (columns) holds each element's coordinates.
    /// The output shape therefore varies with how many true values this tensor contains.
    /// Indices are produced in row-major order.
    ///
    /// For example:
    /// ```
    /// // 'input' is [[true, false], [true, false]]
    /// // 'input' has 2 true values and so the output has 2 rows.
    /// // 'input' has rank of 2, and so the second dimension of the output has size 2.
    /// input.nonZeroIndices() // is [[0, 0], [1, 0]]
    ///
    /// // 'input' is [[[ true, false], [ true, false]],
    /// //             [[false,  true], [false,  true]],
    /// //             [[false, false], [false,  true]]]
    /// // 'input' has 5 true values and so the output has 5 rows.
    /// // 'input' has rank 3, and so the second dimension of the output has size 3.
    /// input.nonZeroIndices() // is [[0, 0, 0],
    /// //                            [0, 1, 0],
    /// //                            [1, 0, 1],
    /// //                            [1, 1, 1],
    /// //                            [2, 1, 1]]
    /// ```
    ///
    /// - Returns: A tensor with shape `(num_true, rank(condition))`.
    @inlinable
    func nonZeroIndices() -> Tensor<Int64> {
        // Thin wrapper over the `Where` op, which emits row-major coordinates of true entries.
        let coordinates = Raw.where_(self)
        return coordinates
    }
}
379577

380578
//===------------------------------------------------------------------------------------------===//

Tests/TensorFlowTests/Helpers.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@
1515
import XCTest
1616
@testable import TensorFlow
1717

18-
internal func assertEqual<T: TensorFlowScalar & Equatable>(_ x: Tensor<T>, _ y: Tensor<T>) {
19-
zip(x.scalars, y.scalars).forEach { (x, y) in
20-
XCTAssertEqual(x, y)
21-
}
22-
}
23-
2418
internal func assertEqual<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ y: Tensor<T>, accuracy: T) {
2519
zip(x.scalars, y.scalars).forEach { (x, y) in
2620
XCTAssertEqual(x, y, accuracy: accuracy)

Tests/TensorFlowTests/OperatorTests/BasicTests.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ precedencegroup StridedRangeFormationPrecedence {
2626
}
2727

2828
final class BasicOperatorTests: XCTestCase {
29+
/// Gathering the scalar index 2 along axis 1 selects column 2 of each row.
func testGathering() {
    let input = Tensor<Float>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    let gathered = input.gathering(atIndices: Tensor<Int32>(2), alongAxis: 1)
    XCTAssertEqual(gathered, Tensor<Float>([3.0, 6.0]))
}
34+
2935
func testElementIndexing() {
3036
// NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly
3137
// until send and receive are implemented (without writing a bunch of mini
@@ -460,6 +466,7 @@ final class BasicOperatorTests: XCTestCase {
460466
}
461467

462468
static var allTests = [
469+
("testGathering", testGathering),
463470
("testElementIndexing", testElementIndexing),
464471
("testElementIndexingAssignment", testElementIndexingAssignment),
465472
("testNestedElementIndexing", testNestedElementIndexing),

0 commit comments

Comments
 (0)