This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Added support for a 'Tensor.gathering(where:)'. #156

Merged · 16 commits · May 31, 2019
198 changes: 198 additions & 0 deletions Sources/TensorFlow/Operators/Basic.swift
@@ -335,6 +335,107 @@ public extension Tensor {
static func ++ (lhs: Tensor, rhs: Tensor) -> Tensor {
return lhs.concatenated(with: rhs)
}

/// Gathers slices of this tensor at `indices` along the `axis` dimension.
Contributor:

Doc comments for non-mutating methods often start with `Returns a ... by <verb>ing ...` instead of `<Verb>s ...`. Quite a few doc comments in the library do not follow this guideline yet and should be updated.

Contributor Author:

Sounds good. I'll try to go through them tomorrow.

Contributor:

Filed #160.

///
/// For 0-D (scalar) `indices`:
/// ```
/// result[p_0, ..., p_{axis-1},
/// p_{axis + 1}, ..., p_{N-1}] =
/// self[p_0, ..., p_{axis-1},
/// indices,
/// p_{axis + 1}, ..., p_{N-1}]
/// ```
///
/// For 1-D (vector) `indices`:
/// ```
/// result[p_0, ..., p_{axis-1},
/// i,
/// p_{axis + 1}, ..., p_{N-1}] =
/// self[p_0, ..., p_{axis-1},
/// indices[i],
/// p_{axis + 1}, ..., p_{N-1}]
/// ```
///
/// In the general case, produces a resulting tensor where:
/// ```
/// result[p_0, ..., p_{axis-1},
/// i_0, ..., i_{M-1},
/// p_{axis + 1}, ..., p_{N-1}] =
/// self[p_0, ..., p_{axis-1},
/// indices[i_0, ..., i_{M-1}],
/// p_{axis + 1}, ..., p_{N-1}]
/// ```
/// where `N = self.rank` and `M = indices.rank`.
///
/// The shape of the resulting tensor is:
/// `self.shape[..<axis] + indices.shape + self.shape[(axis + 1)...]`.
///
/// - Note: On CPU, if an out-of-range index is found, an error is thrown. On GPU, if an
/// out-of-range index is found, a 0 is stored in the corresponding output value.
///
/// - Parameters:
/// - indices: Contains the indices to gather at.
/// - axis: Dimension along which to gather. Negative values wrap around.
///
/// - Precondition: `axis` must be in the range `[-rank, rank)`.
///
/// - Returns: The gathered tensor.
@inlinable
@differentiable(wrt: self, vjp: _vjpGathering where Scalar : TensorFlowFloatingPoint)
func gathering(atIndices indices: Tensor<Int32>, alongAxis axis: Int = 0) -> Tensor {
return Raw.gatherV2(params: self, indices: indices, axis: Tensor<Int32>(Int32(axis)))
}
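As a quick illustration of the shape rule above (a hand-worked sketch using only the API added in this diff; the values are derived from the gather formula, not from a test in this PR):

```swift
let t = Tensor<Float>([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])                    // shape [2, 3]

// Scalar indices drop the gathered dimension:
t.gathering(atIndices: Tensor<Int32>(1), alongAxis: 1)       // [2.0, 5.0], shape [2]

// Vector indices keep it, with size equal to the number of indices:
t.gathering(atIndices: Tensor<Int32>([2, 0]), alongAxis: 1)
// [[3.0, 1.0],
//  [6.0, 4.0]]                                              // shape [2, 2]
```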

/// Gathers values from this tensor according to the provided boolean mask.
///
/// For example:
/// ```
/// // 1-D example
/// // tensor is [0, 1, 2, 3]
/// // mask is [true, false, true, false]
/// tensor.gathering(where: mask) // is [0, 2]
///
/// // 2-D example
/// // tensor is [[1, 2], [3, 4], [5, 6]]
/// // mask is [true, false, true]
/// tensor.gathering(where: mask) // is [[1, 2], [5, 6]]
/// ```
///
/// In general, `0 < mask.rank = K <= tensor.rank`, and the `mask`'s shape must match the first
/// K dimensions of the `tensor`'s shape. We then have:
/// `tensor.gathering(where: mask)[i, j1, ..., jd] = tensor[i1, ..., iK, j1, ..., jd]`, where
/// `[i1, ..., iK]` is the `i`th `true` entry of `mask` (row-major order).
///
/// The `axis` can be used with `mask` to indicate the axis to mask from. In that case,
/// `axis + mask.rank <= tensor.rank` and the `mask`'s shape must match the first
/// `axis + mask.rank` dimensions of the `tensor`'s shape.
///
/// - Parameters:
/// - mask: K-D boolean tensor, where `K <= self.rank`.
/// - axis: 0-D integer tensor representing the axis in `self` to mask from, where
/// `K + axis <= self.rank`.
///
/// - Precondition: The `mask` cannot be a scalar: `mask.rank != 0`.
///
/// - Returns: `(self.rank - K + 1)`-dimensional tensor populated by entries in this tensor
/// corresponding to `true` values in `mask`.
@inlinable
// @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func gathering(where mask: Tensor<Bool>, alongAxis axis: Int = 0) -> Tensor {
precondition(mask.rank != 0, "The boolean mask cannot be a scalar.")
// TODO: Remove once control flow AD is supported.
let rank = self.rank
let posAxis = { axis < 0 ? axis + rank : axis }()
let leadingSize = shapeTensor[posAxis ..< posAxis + mask.rank].product().rankLifted()
let reshapedTensor = reshaped(
toShape: Tensor<Int32>(concatenating: [
shapeTensor[..<posAxis],
leadingSize,
shapeTensor[(posAxis + mask.rank)...]]))
let indices = Tensor<Int32>(mask.flattened().nonZeroIndices().squeezingShape(at: 1))
return reshapedTensor.gathering(atIndices: indices, alongAxis: posAxis)
}
}
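The `alongAxis` behavior of `gathering(where:)` is described above but not exercised by the tests in this PR; a hedged sketch, with values worked out by hand from the semantics above:

```swift
let t = Tensor<Float>([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])                     // shape [2, 3]

// Mask rows (default axis 0): mask.shape == [2] matches shape[0..<1].
t.gathering(where: Tensor<Bool>([false, true]))               // [[4.0, 5.0, 6.0]], shape [1, 3]

// Mask columns (axis 1): mask.shape == [3] matches shape[1..<2].
t.gathering(where: Tensor<Bool>([true, false, true]), alongAxis: 1)
// [[1.0, 3.0],
//  [4.0, 6.0]]                                               // shape [2, 2]
```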

internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@@ -375,6 +476,103 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
return (gradients[0], gradients[1])
})
}

@inlinable
func _vjpGathering(
atIndices indices: Tensor<Int32>,
alongAxis axis: Int = 0
) -> (Tensor, (Tensor) -> Tensor) {
let result = gathering(atIndices: indices, alongAxis: axis)
let posAxis = axis < 0 ? axis + rank : axis

// We have a fast gradient implementation for the case when `posAxis == 0`.
if posAxis == 0 {
return (result, { [shape = shapeTensor] v in
let indicesCount = indices.scalarCountTensor.rankLifted()
let valuesShape = Tensor<Int32>(concatenating: [indicesCount, shape[1...]])
let values = v.reshaped(toShape: valuesShape)
let valueIndices = indices.reshaped(toShape: indicesCount)
return Raw.unsortedSegmentSum(
data: values,
segmentIds: valueIndices,
numSegments: shape[0])
})
}

return (result, { [shape = shapeTensor] v in
let indicesSize = Tensor<Int32>(Int32(indices.scalarCount)).rankLifted()
let outerShape = shape[..<posAxis]
let outerSize = outerShape.scalarCount
let innerShape = shape[(posAxis + 1)...]
let innerSize = innerShape.scalarCount
let outerIndices = Tensor<Int32>(rangeFrom: 0, to: Int32(outerSize), stride: 1)
let innerIndices = Tensor<Int32>(
rangeFrom: Int32(outerSize) + 1,
to: Int32(outerSize) + 1 + Int32(innerSize),
stride: 1)
let valuesShape = Tensor<Int32>(concatenating: [outerShape, indicesSize, innerShape])
let values = v.reshaped(toShape: valuesShape)
let valueIndices = indices.reshaped(toShape: indicesSize)

// We need to sum up every slice `values[..., i, ...]` corresponding to
// `tensor[..., indices[i], ...]`. Since `unsortedSegmentSum` does not support an axis
// parameter, we transpose the gather dimension to the front, then use
// `unsortedSegmentSum` to build a `[gatherAxis, outerAxes, innerAxes]` tensor with all
// the gradients affecting each index in `gatherAxis` summed up.
let permutations = Tensor<Int32>(concatenating: [
Tensor<Int32>([Int32(outerSize)]),
outerIndices,
innerIndices])
let transposedValues = values.transposed(withPermutations: permutations)
let gradient = Raw.unsortedSegmentSum(
data: transposedValues,
segmentIds: valueIndices,
numSegments: shape[posAxis])

// Finally, we invert the above transpose operation by moving dimension 0 back to its
// original position.
let inversePermutations = Tensor<Int32>(concatenating: [
outerIndices + 1,
Tensor<Int32>([0]),
innerIndices])
return gradient.transposed(withPermutations: inversePermutations)
})
}
}
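As a sanity check on the gradient logic above (a hand-worked sketch, not one of this PR's tests; it assumes Swift for TensorFlow's `gradient(at:in:)` is available): gathering with repeated indices must accumulate the incoming gradient once per occurrence, which is exactly what `unsortedSegmentSum` provides.

```swift
// Index 0 is gathered twice, so its gradient entries are summed;
// index 1 is never gathered, so its gradient is zero.
let x = Tensor<Float>([10.0, 20.0, 30.0])
let dx = gradient(at: x) { x in
    x.gathering(atIndices: Tensor<Int32>([0, 2, 0])).sum()
}
// dx == [2.0, 0.0, 1.0]
```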

public extension Tensor {
/// Returns the locations of non-zero / true values in this tensor.
///
/// The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the
/// number of non-zero elements, and the second dimension (columns) represents the coordinates
/// of the non-zero elements. Keep in mind that the shape of the output tensor can vary
/// depending on how many true values there are in this tensor. Indices are output in row-major
/// order.
///
/// For example:
/// ```
/// // 'input' is [[true, false], [true, false]]
/// // 'input' has 2 true values and so the output has 2 rows.
/// // 'input' has rank of 2, and so the second dimension of the output has size 2.
/// input.nonZeroIndices() // is [[0, 0], [1, 0]]
///
/// // 'input' is [[[ true, false], [ true, false]],
/// // [[false, true], [false, true]],
/// // [[false, false], [false, true]]]
/// // 'input' has 5 true values and so the output has 5 rows.
/// // 'input' has rank 3, and so the second dimension of the output has size 3.
/// input.nonZeroIndices() // is [[0, 0, 0],
/// // [0, 1, 0],
/// // [1, 0, 1],
/// // [1, 1, 1],
/// // [2, 1, 1]]
/// ```
///
/// - Returns: A tensor with shape `(num_true, self.rank)`.
@inlinable
func nonZeroIndices() -> Tensor<Int64> {
return Raw.where_(self)
}
}
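For reference, `gathering(where:)` above is built from exactly these pieces; a short hand-worked sketch of the composition (values derived from the doc-comment examples):

```swift
let t = Tensor<Float>([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
let mask = Tensor<Bool>([true, false, true])

// The mask's true positions become gather indices along the masked axis.
let coords = mask.nonZeroIndices()                            // [[0], [2]], shape [2, 1]
let indices = Tensor<Int32>(coords.squeezingShape(at: 1))     // [0, 2]
t.gathering(atIndices: indices)                               // [[1.0, 2.0], [5.0, 6.0]]
```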

//===------------------------------------------------------------------------------------------===//
6 changes: 0 additions & 6 deletions Tests/TensorFlowTests/Helpers.swift
@@ -15,12 +15,6 @@
import XCTest
@testable import TensorFlow

internal func assertEqual<T: TensorFlowScalar & Equatable>(_ x: Tensor<T>, _ y: Tensor<T>) {
zip(x.scalars, y.scalars).forEach { (x, y) in
XCTAssertEqual(x, y)
}
}

internal func assertEqual<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ y: Tensor<T>, accuracy: T) {
zip(x.scalars, y.scalars).forEach { (x, y) in
XCTAssertEqual(x, y, accuracy: accuracy)
7 changes: 7 additions & 0 deletions Tests/TensorFlowTests/OperatorTests/BasicTests.swift
@@ -26,6 +26,12 @@ precedencegroup StridedRangeFormationPrecedence {
}

final class BasicOperatorTests: XCTestCase {
func testGathering() {
let x = Tensor<Float>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
let y = x.gathering(atIndices: Tensor<Int32>(2), alongAxis: 1)
XCTAssertEqual(y, Tensor<Float>([3.0, 6.0]))
}
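The committed test exercises `gathering(atIndices:alongAxis:)` only; a sketch of what a companion test for the new boolean-mask overload might look like (illustrative, not part of this PR, and the test name is hypothetical):

```swift
func testGatheringWithBooleanMask() {
    // Mirrors the 2-D example from the doc comment of `gathering(where:)`.
    let x = Tensor<Float>([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    let mask = Tensor<Bool>([true, false, true])
    XCTAssertEqual(x.gathering(where: mask), Tensor<Float>([[1.0, 2.0], [5.0, 6.0]]))
}
```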

func testElementIndexing() {
// NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly
// until send and receive are implemented (without writing a bunch of mini
@@ -460,6 +466,7 @@ final class BasicOperatorTests: XCTestCase {
}

static var allTests = [
("testGathering", testGathering),
("testElementIndexing", testElementIndexing),
("testElementIndexingAssignment", testElementIndexingAssignment),
("testNestedElementIndexing", testNestedElementIndexing),