Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 2aed9be

Browse files
authored
Enhanced the 'Tensor.batchGathering(atIndices:)' function. (#327)
1 parent: fdda63e · commit: 2aed9be

File tree

2 files changed

+93
-38
lines changed

2 files changed

+93
-38
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 81 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ infix operator .!=: ComparisonPrecedence
1818
@inlinable
1919
@differentiable(where Scalar: TensorFlowFloatingPoint)
2020
public func identity<Scalar>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
21-
return x
21+
x
2222
}
2323

2424
//===------------------------------------------------------------------------------------------===//
@@ -390,44 +390,97 @@ public extension Tensor {
390390
return Raw.gatherV2(params: self, indices: indices, axis: Tensor<Int32>(Int32(axis)))
391391
}
392392

393-
/// Returns slices of this tensor at `indices`, while ignoring the first `batchDims` dimensions
394-
/// that correspond to batch dimensions. The gather is performed along the first non-batch
395-
/// dimension.
393+
/// Returns slices of this tensor at `indices` along the `axis` dimension, while ignoring the
394+
/// first `batchDimensionCount` dimensions that correspond to batch dimensions. The gather is
395+
/// performed along the first non-batch dimension.
396396
///
397397
/// Performs similar functionality to `gathering`, except that the resulting tensor shape is
398-
/// now:
399-
/// ```
400-
/// self.shape[..<batchDims] +
401-
/// indices.shape[batchDims...] +
402-
/// self.shape[(batchDims + indices.rank + 1)...]
403-
/// ```
398+
/// now `shape[..<axis] + indices.shape[batchDimensionCount...] + shape[(axis + 1)...]`.
404399
///
405400
/// - Parameters:
406401
/// - indices: Contains the indices to gather.
407-
/// - batchDims: Number of leading batch dimensions to ignore.
402+
/// - axis: Dimension along which to gather. Negative values wrap around.
403+
/// - batchDimensionCount: Number of leading batch dimensions to ignore.
408404
///
409-
/// - Precondition: `batchDims` must be less than `indices.rank`.
405+
/// - Precondition: `axis` must be in the range `-rank..<rank`, while also being greater than
406+
/// or equal to `batchDimensionCount`.
407+
/// - Precondition: `batchDimensionCount` must be less than `indices.rank`.
410408
///
411409
/// - Returns: The gathered tensor.
412410
@inlinable
413411
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
414-
func batchGathering<Index: TensorFlowIndex>(atIndices indices: Tensor<Index>) -> Tensor {
415-
var batchIndices = indices
416-
var accumulated = Tensor<Index>(ones: [])
417-
accumulated *= withoutDerivative(at: shapeTensor) { Tensor<Index>($0[1]) }
418-
let dValue = withoutDerivative(at: shapeTensor) { $0[0] }
419-
let dIndices = Tensor<Index>(
420-
rangeFrom: Tensor<Index>(zeros: []),
421-
to: Tensor<Index>(dValue),
422-
stride: Tensor<Index>(ones: [])
423-
) * accumulated
424-
let dShape = Tensor<Int32>(concatenating: [
425-
dValue.rankLifted(),
426-
Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
427-
batchIndices += dIndices.reshaped(toShape: dShape)
412+
func batchGathering<Index: TensorFlowIndex>(
413+
atIndices indices: Tensor<Index>,
414+
alongAxis axis: Int = 1,
415+
batchDimensionCount: Int = 1
416+
) -> Tensor {
417+
// TODO: precondition(batchDimensionCount >= 0,
418+
// "'batchDimensionCount' must be non-negative.")
419+
// TODO: precondition(batchDimensionCount < indices.rank,
420+
// "'batchDimensionCount' must be less than 'indices.rank'.")
421+
// TODO: precondition(batchDimensionCount < rank,
422+
// "'batchDimensionCount' must be less than the tensor's rank.")
423+
424+
// Handle the axis argument by transposing the axis dimension so that it is the first
425+
// non-batch dimension, recursively calling `batchGathering` with `axis = 0`, and then
426+
// transposing the result to put the pre-axis dimensions before the indices dimensions.
427+
if axis != batchDimensionCount {
428+
// Adjust axis to be positive.
429+
let posAxis = axis < 0 ? axis + rank : axis
430+
431+
// TODO: precondition(posAxis >= 0 && posAxis < rank, "'axis' is out of range.")
432+
// TODO: precondition(batchDimensionCount <= posAxis,
433+
// "'batchDimensionCount' must be less than or equal to 'axis'.")
434+
435+
// Move self[axis] up to self[batchDimensionCount].
436+
let permutation = Tensor<Int32>(concatenating: [
437+
Tensor<Int32>(rangeFrom: 0, to: Int32(batchDimensionCount), stride: 1),
438+
Tensor<Int32>(Int32(axis)).rankLifted(),
439+
Tensor<Int32>(rangeFrom: Int32(batchDimensionCount), to: Int32(posAxis), stride: 1),
440+
Tensor<Int32>(rangeFrom: Int32(axis) + 1, to: Int32(rank), stride: 1)])
441+
let tensor = transposed(withPermutations: permutation)
442+
let result = tensor.batchGathering(
443+
atIndices: indices,
444+
alongAxis: batchDimensionCount,
445+
batchDimensionCount: batchDimensionCount)
446+
447+
// Move the result dimensions corresponding to self[batchDimensionCount..<axis] to
448+
// just before the dimensions corresponding to indices[batchDimensionCount...].
449+
let start = indices.rank + posAxis - batchDimensionCount
450+
let resultPermutation = Tensor<Int32>(concatenating: [
451+
Tensor<Int32>(rangeFrom: 0, to: Int32(batchDimensionCount), stride: 1),
452+
Tensor<Int32>(rangeFrom: Int32(indices.rank), to: Int32(start), stride: 1),
453+
Tensor<Int32>(
454+
rangeFrom: Int32(batchDimensionCount),
455+
to: Int32(indices.rank),
456+
stride: 1),
457+
Tensor<Int32>(rangeFrom: Int32(start), to: Int32(result.rank), stride: 1)])
458+
return result.transposed(withPermutations: resultPermutation)
459+
}
460+
461+
let batchIndices: Tensor<Index> = withoutDerivative(at: {
462+
var batchIndices = indices
463+
var accumulated = Tensor<Index>(ones: [])
464+
for d in (1...batchDimensionCount).reversed() {
465+
accumulated *= Tensor<Index>(self.shapeTensor[d])
466+
let dValue = self.shapeTensor[d - 1]
467+
let dIndices = Tensor<Index>(
468+
rangeFrom: Tensor<Index>(zeros: []),
469+
to: Tensor<Index>(dValue),
470+
stride: Tensor<Index>(ones: [])
471+
) * accumulated
472+
let dShape = Tensor<Int32>(concatenating: [
473+
Tensor<Int32>([Int32](repeating: 1, count: d - 1)),
474+
dValue.rankLifted(),
475+
Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
476+
batchIndices += dIndices.reshaped(toShape: dShape)
477+
}
478+
return batchIndices
479+
}())
480+
428481
let flatIndices = batchIndices.flattened()
429-
let outerShape = withoutDerivative(at: shapeTensor) { $0[2...] }
430-
let innerShape = withoutDerivative(at: shapeTensor) { $0[..<2] }.product(squeezingAxes: [0])
482+
let outerShape = shapeTensor[(batchDimensionCount + 1)...]
483+
let innerShape = shapeTensor[..<(batchDimensionCount + 1)].product(squeezingAxes: [0])
431484
let flatTensor = reshaped(toShape: innerShape.rankLifted().concatenated(with: outerShape))
432485
let flatResult = flatTensor.gathering(atIndices: flatIndices)
433486
return flatResult.reshaped(toShape: indices.shapeTensor.concatenated(with: outerShape))

Tests/TensorFlowTests/OperatorTests/BasicTests.swift

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,24 @@
1515
import XCTest
1616
@testable import TensorFlow
1717

18-
infix operator ++: AdditionPrecedence
19-
infix operator .=
20-
21-
infix operator ..: StridedRangeFormationPrecedence
22-
precedencegroup StridedRangeFormationPrecedence {
23-
associativity: left
24-
higherThan: CastingPrecedence
25-
lowerThan: RangeFormationPrecedence
26-
}
27-
2818
final class BasicOperatorTests: XCTestCase {
2919
func testGathering() {
3020
let x = Tensor<Float>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
3121
let y = x.gathering(atIndices: Tensor<Int32>(2), alongAxis: 1)
3222
XCTAssertEqual(y, Tensor<Float>([3.0, 6.0]))
3323
}
3424

25+
func testBatchGathering() {
26+
let x = Tensor<Float>([[
27+
[1.0, 2.0, 3.0],
28+
[4.0, 5.0, 6.0]]])
29+
let y = x.batchGathering(
30+
atIndices: Tensor<Int32>([1, 0]),
31+
alongAxis: 2,
32+
batchDimensionCount: 2)
33+
XCTAssertEqual(y, Tensor<Float>([2.0, 4.0]))
34+
}
35+
3536
func testPadded() {
3637
let x = Tensor<Float>(ones: [2, 2])
3738
let target = Tensor<Float>([[3, 3, 3], [1, 1, 3], [1, 1, 3]])
@@ -596,6 +597,7 @@ final class BasicOperatorTests: XCTestCase {
596597

597598
static var allTests = [
598599
("testGathering", testGathering),
600+
("testBatchGathering", testBatchGathering),
599601
("testPadded", testPadded),
600602
("testVJPPadded", testVJPPadded),
601603
("testElementIndexing", testElementIndexing),

0 commit comments

Comments (0)