@@ -383,7 +383,10 @@ public extension Tensor {
383
383
/// - Returns: The gathered tensor.
384
384
@inlinable
385
385
@differentiable ( wrt: self , vjp: _vjpGathering where Scalar : TensorFlowFloatingPoint)
386
- func gathering( atIndices indices: Tensor < Int32 > , alongAxis axis: Int = 0 ) -> Tensor {
386
+ func gathering< Index: TensorFlowIndex > (
387
+ atIndices indices: Tensor < Index > ,
388
+ alongAxis axis: Int = 0
389
+ ) -> Tensor {
387
390
return Raw . gatherV2 ( params: self , indices: indices, axis: Tensor < Int32 > ( Int32 ( axis) ) )
388
391
}
389
392
@@ -408,15 +411,15 @@ public extension Tensor {
408
411
/// - Returns: The gathered tensor.
409
412
@inlinable
410
413
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
411
- func batchGathering( atIndices indices: Tensor < Int32 > ) -> Tensor {
414
+ func batchGathering< Index : TensorFlowIndex > ( atIndices indices: Tensor < Index > ) -> Tensor {
412
415
var batchIndices = indices
413
- var accumulated = Tensor < Int32 > ( ones: [ ] )
414
- accumulated *= Swift . withoutDerivative ( at: shapeTensor) { $0 [ 1 ] }
416
+ var accumulated = Tensor < Index > ( ones: [ ] )
417
+ accumulated *= Swift . withoutDerivative ( at: shapeTensor) { Tensor < Index > ( $0 [ 1 ] ) }
415
418
let dValue = Swift . withoutDerivative ( at: shapeTensor) { $0 [ 0 ] }
416
- let dIndices = Tensor < Int32 > (
417
- rangeFrom: Tensor < Int32 > ( zeros: [ ] ) ,
418
- to: dValue,
419
- stride: Tensor < Int32 > ( ones: [ ] )
419
+ let dIndices = Tensor < Index > (
420
+ rangeFrom: Tensor < Index > ( zeros: [ ] ) ,
421
+ to: Tensor < Index > ( dValue) ,
422
+ stride: Tensor < Index > ( ones: [ ] )
420
423
) * accumulated
421
424
let dShape = Tensor < Int32 > ( concatenating: [
422
425
dValue. rankLifted ( ) ,
@@ -519,8 +522,8 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
519
522
}
520
523
521
524
@inlinable
522
- func _vjpGathering(
523
- atIndices indices: Tensor < Int32 > ,
525
+ func _vjpGathering< Index : TensorFlowIndex > (
526
+ atIndices indices: Tensor < Index > ,
524
527
alongAxis axis: Int = 0
525
528
) -> ( Tensor , ( Tensor ) -> Tensor ) {
526
529
let result = gathering ( atIndices: indices, alongAxis: axis)
0 commit comments