
Enhanced the 'Tensor.batchGathering(atIndices:)' function. #327

Merged (5 commits) on Jul 2, 2019
109 changes: 81 additions & 28 deletions Sources/TensorFlow/Operators/Basic.swift
@@ -18,7 +18,7 @@ infix operator .!=: ComparisonPrecedence
 @inlinable
 @differentiable(where Scalar: TensorFlowFloatingPoint)
 public func identity<Scalar>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
-    return x
+    x
 }

 //===------------------------------------------------------------------------------------------===//
@@ -390,44 +390,97 @@ public extension Tensor {
         return Raw.gatherV2(params: self, indices: indices, axis: Tensor<Int32>(Int32(axis)))
     }

-    /// Returns slices of this tensor at `indices`, while ignoring the first `batchDims` dimensions
-    /// that correspond to batch dimensions. The gather is performed along the first non-batch
-    /// dimension.
+    /// Returns slices of this tensor at `indices` along the `axis` dimension, while ignoring
+    /// the first `batchDimensionCount` dimensions that correspond to batch dimensions. The
+    /// gather is performed along `axis`, which defaults to the first non-batch dimension.
     ///
-    /// Performs similar functionality to `gathering`, except that the resulting tensor shape is
-    /// now:
-    /// ```
-    /// self.shape[..<batchDims] +
-    /// indices.shape[batchDims...] +
-    /// self.shape[(batchDims + indices.rank + 1)...]
-    /// ```
+    /// Performs similar functionality to `gathering`, except that the resulting tensor shape
+    /// is now `shape[..<axis] + indices.shape[batchDimensionCount...] + shape[(axis + 1)...]`.
     ///
     /// - Parameters:
     ///   - indices: Contains the indices to gather.
-    ///   - batchDims: Number of leading batch dimensions to ignore.
+    ///   - axis: Dimension along which to gather. Negative values wrap around.
+    ///   - batchDimensionCount: Number of leading batch dimensions to ignore.
     ///
-    /// - Precondition: `batchDims` must be less than `indices.rank`.
+    /// - Precondition: `axis` must be in the range `-rank..<rank` and, after any negative
+    ///   value wraps around, must be greater than or equal to `batchDimensionCount`.
+    /// - Precondition: `batchDimensionCount` must be less than `indices.rank`.
     ///
     /// - Returns: The gathered tensor.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func batchGathering<Index: TensorFlowIndex>(atIndices indices: Tensor<Index>) -> Tensor {
-        var batchIndices = indices
-        var accumulated = Tensor<Index>(ones: [])
-        accumulated *= withoutDerivative(at: shapeTensor) { Tensor<Index>($0[1]) }
-        let dValue = withoutDerivative(at: shapeTensor) { $0[0] }
-        let dIndices = Tensor<Index>(
-            rangeFrom: Tensor<Index>(zeros: []),
-            to: Tensor<Index>(dValue),
-            stride: Tensor<Index>(ones: [])
-        ) * accumulated
-        let dShape = Tensor<Int32>(concatenating: [
-            dValue.rankLifted(),
-            Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
-        batchIndices += dIndices.reshaped(toShape: dShape)
+    func batchGathering<Index: TensorFlowIndex>(
+        atIndices indices: Tensor<Index>,
+        alongAxis axis: Int = 1,
+        batchDimensionCount: Int = 1
+    ) -> Tensor {
+        // TODO: precondition(batchDimensionCount >= 0,
+        //     "'batchDimensionCount' must be non-negative.")
+        // TODO: precondition(batchDimensionCount < indices.rank,
+        //     "'batchDimensionCount' must be less than 'indices.rank'.")
+        // TODO: precondition(batchDimensionCount < rank,
+        //     "'batchDimensionCount' must be less than the tensor's rank.")

+        // Handle the axis argument by transposing the axis dimension so that it is the first
+        // non-batch dimension, recursively calling `batchGathering` with
+        // `axis = batchDimensionCount`, and then transposing the result to put the pre-axis
+        // dimensions before the indices dimensions.
+        if axis != batchDimensionCount {
+            // Adjust axis to be positive.
+            let posAxis = axis < 0 ? axis + rank : axis
+
+            // TODO: precondition(posAxis >= 0 && posAxis < rank, "'axis' is out of range.")
+            // TODO: precondition(batchDimensionCount <= posAxis,
+            //     "'batchDimensionCount' must be less than or equal to 'axis'.")
+
+            // Move self[posAxis] up to self[batchDimensionCount].
+            let permutation = Tensor<Int32>(concatenating: [
+                Tensor<Int32>(rangeFrom: 0, to: Int32(batchDimensionCount), stride: 1),
+                Tensor<Int32>(Int32(posAxis)).rankLifted(),
+                Tensor<Int32>(rangeFrom: Int32(batchDimensionCount), to: Int32(posAxis), stride: 1),
+                Tensor<Int32>(rangeFrom: Int32(posAxis) + 1, to: Int32(rank), stride: 1)])
+            let tensor = transposed(withPermutations: permutation)
+            let result = tensor.batchGathering(
+                atIndices: indices,
+                alongAxis: batchDimensionCount,
+                batchDimensionCount: batchDimensionCount)
+
+            // Move the result dimensions corresponding to self[batchDimensionCount..<posAxis]
+            // to just before the dimensions corresponding to indices[batchDimensionCount...].
+            let start = indices.rank + posAxis - batchDimensionCount
+            let resultPermutation = Tensor<Int32>(concatenating: [
+                Tensor<Int32>(rangeFrom: 0, to: Int32(batchDimensionCount), stride: 1),
+                Tensor<Int32>(rangeFrom: Int32(indices.rank), to: Int32(start), stride: 1),
+                Tensor<Int32>(rangeFrom: Int32(batchDimensionCount), to: Int32(indices.rank), stride: 1),
+                Tensor<Int32>(rangeFrom: Int32(start), to: Int32(result.rank), stride: 1)])
+            return result.transposed(withPermutations: resultPermutation)
+        }

+        // Offset `indices` so that they index into the tensor with its leading
+        // `batchDimensionCount + 1` dimensions flattened: index `i` along batch dimension
+        // `d - 1` contributes an offset of `i * shape[d] * ... * shape[batchDimensionCount]`.
+        let batchIndices: Tensor<Index> = withoutDerivative(at: {
+            var batchIndices = indices
+            var accumulated = Tensor<Index>(ones: [])
+            for d in (1...batchDimensionCount).reversed() {
+                accumulated *= Tensor<Index>(self.shapeTensor[d])
+                let dValue = self.shapeTensor[d - 1]
+                let dIndices = Tensor<Index>(
+                    rangeFrom: Tensor<Index>(zeros: []),
+                    to: Tensor<Index>(dValue),
+                    stride: Tensor<Index>(ones: [])
+                ) * accumulated
+                let dShape = Tensor<Int32>(concatenating: [
+                    Tensor<Int32>([Int32](repeating: 1, count: d - 1)),
+                    dValue.rankLifted(),
+                    Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
+                batchIndices += dIndices.reshaped(toShape: dShape)
+            }
+            return batchIndices
+        }())

         let flatIndices = batchIndices.flattened()
-        let outerShape = withoutDerivative(at: shapeTensor) { $0[2...] }
-        let innerShape = withoutDerivative(at: shapeTensor) { $0[..<2] }.product(squeezingAxes: [0])
+        let outerShape = shapeTensor[(batchDimensionCount + 1)...]
+        let innerShape = shapeTensor[..<(batchDimensionCount + 1)].product(squeezingAxes: [0])
         let flatTensor = reshaped(toShape: innerShape.rankLifted().concatenated(with: outerShape))
         let flatResult = flatTensor.gathering(atIndices: flatIndices)
         return flatResult.reshaped(toShape: indices.shapeTensor.concatenated(with: outerShape))
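Below is a minimal usage sketch of the enhanced API. It is not part of the diff: a hypothetical example assuming the code is built against this branch's `TensorFlow` module, using the package's existing `Tensor(rangeFrom:to:stride:)` and `reshaped(to:)` initializer and reshape APIs only to build an input. It illustrates the documented result shape `shape[..<axis] + indices.shape[batchDimensionCount...] + shape[(axis + 1)...]`:

```swift
import TensorFlow

// A batch of 2 matrices, each 3x4: shape [2, 3, 4].
let params = Tensor<Float>(rangeFrom: 0, to: 24, stride: 1).reshaped(to: [2, 3, 4])

// For each batch element, pick two rows: indices shape [2, 2].
let indices = Tensor<Int32>([[0, 2], [1, 1]])

// Gather along axis 1, treating the leading dimension as a batch dimension.
let gathered = params.batchGathering(
    atIndices: indices,
    alongAxis: 1,
    batchDimensionCount: 1)

// Result shape: shape[..<1] + indices.shape[1...] + shape[2...] = [2, 2, 4],
// where gathered[b, k] == params[b, indices[b, k]].
print(gathered.shape)  // [2, 2, 4]
```

Since both new parameters default to 1, `x.batchGathering(atIndices: i)` still compiles and reduces to the old behavior, so existing call sites are unaffected.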
22 changes: 12 additions & 10 deletions Tests/TensorFlowTests/OperatorTests/BasicTests.swift
@@ -15,23 +15,24 @@
 import XCTest
 @testable import TensorFlow

 infix operator ++: AdditionPrecedence
 infix operator .=

-infix operator ..: StridedRangeFormationPrecedence
-precedencegroup StridedRangeFormationPrecedence {
-    associativity: left
-    higherThan: CastingPrecedence
-    lowerThan: RangeFormationPrecedence
-}

 final class BasicOperatorTests: XCTestCase {
     func testGathering() {
         let x = Tensor<Float>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
         let y = x.gathering(atIndices: Tensor<Int32>(2), alongAxis: 1)
         XCTAssertEqual(y, Tensor<Float>([3.0, 6.0]))
     }

+    func testBatchGathering() {
+        let x = Tensor<Float>([[
+            [1.0, 2.0, 3.0],
+            [4.0, 5.0, 6.0]]])
+        let y = x.batchGathering(
+            atIndices: Tensor<Int32>([1, 0]),
+            alongAxis: 2,
+            batchDimensionCount: 2)
+        XCTAssertEqual(y, Tensor<Float>([2.0, 4.0]))
+    }

     func testPadded() {
         let x = Tensor<Float>(ones: [2, 2])
         let target = Tensor<Float>([[3, 3, 3], [1, 1, 3], [1, 1, 3]])
@@ -596,6 +597,7 @@ final class BasicOperatorTests: XCTestCase {

     static var allTests = [
         ("testGathering", testGathering),
+        ("testBatchGathering", testBatchGathering),
         ("testPadded", testPadded),
         ("testVJPPadded", testVJPPadded),
         ("testElementIndexing", testElementIndexing),
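Note that neither test exercises the transposition branch, since `testBatchGathering` uses `axis == batchDimensionCount`. The sketch below, again a hypothetical illustration assuming this branch rather than code from the PR, gathers along a trailing axis and therefore does take that branch:

```swift
import TensorFlow

// Gather feature columns (axis 2) per batch element, with one batch dimension.
let params = Tensor<Float>(rangeFrom: 0, to: 24, stride: 1).reshaped(to: [2, 3, 4])
let indices = Tensor<Int32>([[0, 3], [1, 2]])  // shape [2, 2]

let result = params.batchGathering(
    atIndices: indices,
    alongAxis: 2,
    batchDimensionCount: 1)

// Since axis (2) != batchDimensionCount (1), the implementation first permutes
// [2, 3, 4] -> [2, 4, 3] so the gather axis becomes the first non-batch dimension,
// gathers there, then permutes back, so result[b, r, k] == params[b, r, indices[b, k]].
// Result shape: shape[..<2] + indices.shape[1...] + shape[3...] = [2, 3, 2].
print(result.shape)  // [2, 3, 2]
```

The transposition keeps the flattening path simple: only the `axis == batchDimensionCount` case ever reaches the batch-offset and gather logic.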