@@ -414,12 +414,15 @@ public extension Tensor {
414
414
alongAxis axis: Int = 1 ,
415
415
batchDimensionCount: Int = 1
416
416
) -> Tensor {
417
- // TODO: precondition(batchDimensionCount >= 0,
418
- // "'batchDimensionCount' must be non-negative.")
419
- // TODO: precondition(batchDimensionCount < indices.rank,
420
- // "'batchDimensionCount' must be less than 'indices.rank'.")
421
- // TODO: precondition(batchDimensionCount < rank,
422
- // "'batchDimensionCount' must be less than the tensor's rank.")
417
+ precondition ( batchDimensionCount >= 0 , " 'batchDimensionCount' must be non-negative. " )
418
+ precondition (
419
+ batchDimensionCount < indices. rank,
420
+ " 'batchDimensionCount' must be less than 'indices.rank'. " )
421
+ withoutDerivative ( at: rank) {
422
+ precondition (
423
+ batchDimensionCount < $0,
424
+ " 'batchDimensionCount' must be less than the tensor's rank. " )
425
+ }
423
426
424
427
// Handle the axis argument by transposing the axis dimension so that it is the first
425
428
// non-batch dimension, recursively calling `batchGathering` with `axis = 0`, and then
@@ -472,7 +475,7 @@ public extension Tensor {
472
475
let dShape = Tensor < Int32 > ( concatenating: [
473
476
Tensor < Int32 > ( [ Int32] ( repeating: 1 , count: d - 1 ) ) ,
474
477
dValue. rankLifted ( ) ,
475
- Tensor < Int32 > ( [ Int32] ( repeating: 1 , count: indices. rank - 1 ) ) ] )
478
+ Tensor < Int32 > ( [ Int32] ( repeating: 1 , count: indices. rank - d ) ) ] )
476
479
batchIndices += dIndices. reshaped ( toShape: dShape)
477
480
}
478
481
return batchIndices
0 commit comments