@@ -18,7 +18,7 @@ infix operator .!=: ComparisonPrecedence
18
18
@inlinable
19
19
@differentiable ( where Scalar: TensorFlowFloatingPoint)
20
20
public func identity< Scalar> ( _ x: Tensor < Scalar > ) -> Tensor < Scalar > {
21
- return x
21
+ x
22
22
}
23
23
24
24
//===------------------------------------------------------------------------------------------===//
@@ -390,44 +390,97 @@ public extension Tensor {
390
390
return Raw . gatherV2 ( params: self , indices: indices, axis: Tensor < Int32 > ( Int32 ( axis) ) )
391
391
}
392
392
393
- /// Returns slices of this tensor at `indices`, while ignoring the first `batchDims` dimensions
394
- /// that correspond to batch dimensions. The gather is performed along the first non-batch
395
- /// dimension.
393
+ /// Returns slices of this tensor at `indices` along the `axis` dimension , while ignoring the
394
+ /// first `batchDimensionCount` dimensions that correspond to batch dimensions. The gather is
395
+ /// performed along the first non-batch dimension.
396
396
///
397
397
/// Performs similar functionality to `gathering`, except that the resulting tensor shape is
398
- /// now:
399
- /// ```
400
- /// self.shape[..<batchDims] +
401
- /// indices.shape[batchDims...] +
402
- /// self.shape[(batchDims + indices.rank + 1)...]
403
- /// ```
398
+ /// now `shape[..<axis] + indices.shape[batchDimensionCount...] + shape[(axis + 1)...]`.
404
399
///
405
400
/// - Parameters:
406
401
/// - indices: Contains the indices to gather.
407
- /// - batchDims: Number of leading batch dimensions to ignore.
402
+ /// - axis: Dimension along which to gather. Negative values wrap around.
403
+ /// - batchDimensionCount: Number of leading batch dimensions to ignore.
408
404
///
409
- /// - Precondition: `batchDims` must be less than `indices.rank`.
405
+ /// - Precondition: `axis` must be in the range `-rank..<rank`, while also being greater than
406
+ /// or equal to `batchDimensionCount`.
407
+ /// - Precondition: `batchDimensionCount` must be less than `indices.rank`.
410
408
///
411
409
/// - Returns: The gathered tensor.
412
410
@inlinable
413
411
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
414
- func batchGathering< Index: TensorFlowIndex > ( atIndices indices: Tensor < Index > ) -> Tensor {
415
- var batchIndices = indices
416
- var accumulated = Tensor < Index > ( ones: [ ] )
417
- accumulated *= withoutDerivative ( at: shapeTensor) { Tensor < Index > ( $0 [ 1 ] ) }
418
- let dValue = withoutDerivative ( at: shapeTensor) { $0 [ 0 ] }
419
- let dIndices = Tensor < Index > (
420
- rangeFrom: Tensor < Index > ( zeros: [ ] ) ,
421
- to: Tensor < Index > ( dValue) ,
422
- stride: Tensor < Index > ( ones: [ ] )
423
- ) * accumulated
424
- let dShape = Tensor < Int32 > ( concatenating: [
425
- dValue. rankLifted ( ) ,
426
- Tensor < Int32 > ( [ Int32] ( repeating: 1 , count: indices. rank - 1 ) ) ] )
427
- batchIndices += dIndices. reshaped ( toShape: dShape)
412
+ func batchGathering< Index: TensorFlowIndex > (
413
+ atIndices indices: Tensor < Index > ,
414
+ alongAxis axis: Int = 1 ,
415
+ batchDimensionCount: Int = 1
416
+ ) -> Tensor {
417
+ // TODO: precondition(batchDimensionCount >= 0,
418
+ // "'batchDimensionCount' must be non-negative.")
419
+ // TODO: precondition(batchDimensionCount < indices.rank,
420
+ // "'batchDimensionCount' must be less than 'indices.rank'.")
421
+ // TODO: precondition(batchDimensionCount < rank,
422
+ // "'batchDimensionCount' must be less than the tensor's rank.")
423
+
424
+ // Handle the axis argument by transposing the axis dimension so that it is the first
425
+ // non-batch dimension, recursively calling `batchGathering` with `axis = 0`, and then
426
+ // transposing the result to put the pre-axis dimensions before the indices dimensions.
427
+ if axis != batchDimensionCount {
428
+ // Adjust axis to be positive.
429
+ let posAxis = axis < 0 ? axis + rank : axis
430
+
431
+ // TODO: precondition(posAxis >= 0 && posAxis < rank, "'axis' is out of range.")
432
+ // TODO: precondition(batchDimensionCount <= posAxis,
433
+ // "'batchDimensionCount' must be less than or equal to 'axis'.")
434
+
435
+ // Move self[axis] up to self[batchDimensionCount].
436
+ let permutation = Tensor < Int32 > ( concatenating: [
437
+ Tensor < Int32 > ( rangeFrom: 0 , to: Int32 ( batchDimensionCount) , stride: 1 ) ,
438
+ Tensor < Int32 > ( Int32 ( axis) ) . rankLifted ( ) ,
439
+ Tensor < Int32 > ( rangeFrom: Int32 ( batchDimensionCount) , to: Int32 ( posAxis) , stride: 1 ) ,
440
+ Tensor < Int32 > ( rangeFrom: Int32 ( axis) + 1 , to: Int32 ( rank) , stride: 1 ) ] )
441
+ let tensor = transposed ( withPermutations: permutation)
442
+ let result = tensor. batchGathering (
443
+ atIndices: indices,
444
+ alongAxis: batchDimensionCount,
445
+ batchDimensionCount: batchDimensionCount)
446
+
447
+ // Move the result dimensions corresponding to self[batchDimensionCount..<axis] to
448
+ // just before the dimensions corresponding to indices[batchDimensionCount...].
449
+ let start = indices. rank + posAxis - batchDimensionCount
450
+ let resultPermutation = Tensor < Int32 > ( concatenating: [
451
+ Tensor < Int32 > ( rangeFrom: 0 , to: Int32 ( batchDimensionCount) , stride: 1 ) ,
452
+ Tensor < Int32 > ( rangeFrom: Int32 ( indices. rank) , to: Int32 ( start) , stride: 1 ) ,
453
+ Tensor < Int32 > (
454
+ rangeFrom: Int32 ( batchDimensionCount) ,
455
+ to: Int32 ( indices. rank) ,
456
+ stride: 1 ) ,
457
+ Tensor < Int32 > ( rangeFrom: Int32 ( start) , to: Int32 ( result. rank) , stride: 1 ) ] )
458
+ return result. transposed ( withPermutations: resultPermutation)
459
+ }
460
+
461
+ let batchIndices : Tensor < Index > = withoutDerivative ( at: {
462
+ var batchIndices = indices
463
+ var accumulated = Tensor < Index > ( ones: [ ] )
464
+ for d in ( 1 ... batchDimensionCount) . reversed ( ) {
465
+ accumulated *= Tensor < Index > ( self . shapeTensor [ d] )
466
+ let dValue = self . shapeTensor [ d - 1 ]
467
+ let dIndices = Tensor < Index > (
468
+ rangeFrom: Tensor < Index > ( zeros: [ ] ) ,
469
+ to: Tensor < Index > ( dValue) ,
470
+ stride: Tensor < Index > ( ones: [ ] )
471
+ ) * accumulated
472
+ let dShape = Tensor < Int32 > ( concatenating: [
473
+ Tensor < Int32 > ( [ Int32] ( repeating: 1 , count: d - 1 ) ) ,
474
+ dValue. rankLifted ( ) ,
475
+ Tensor < Int32 > ( [ Int32] ( repeating: 1 , count: indices. rank - 1 ) ) ] )
476
+ batchIndices += dIndices. reshaped ( toShape: dShape)
477
+ }
478
+ return batchIndices
479
+ } ( ) )
480
+
428
481
let flatIndices = batchIndices. flattened ( )
429
- let outerShape = withoutDerivative ( at : shapeTensor ) { $0 [ 2 ... ] }
430
- let innerShape = withoutDerivative ( at : shapeTensor) { $0 [ ..< 2 ] } . product ( squeezingAxes: [ 0 ] )
482
+ let outerShape = shapeTensor [ ( batchDimensionCount + 1 ) ... ]
483
+ let innerShape = shapeTensor [ ..< ( batchDimensionCount + 1 ) ] . product ( squeezingAxes: [ 0 ] )
431
484
let flatTensor = reshaped ( toShape: innerShape. rankLifted ( ) . concatenated ( with: outerShape) )
432
485
let flatResult = flatTensor. gathering ( atIndices: flatIndices)
433
486
return flatResult. reshaped ( toShape: indices. shapeTensor. concatenated ( with: outerShape) )
0 commit comments