Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 2be7fd9

Browse files
eaplataniosrxwei
authored andcommitted
Fixed a bug in 'Tensor.batchGathering(atIndices:alongAxis:batchDimensionsCount:)'. (#359)
1 parent 1f5e7fa commit 2be7fd9

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -414,12 +414,15 @@ public extension Tensor {
414414
alongAxis axis: Int = 1,
415415
batchDimensionCount: Int = 1
416416
) -> Tensor {
417-
// TODO: precondition(batchDimensionCount >= 0,
418-
// "'batchDimensionCount' must be non-negative.")
419-
// TODO: precondition(batchDimensionCount < indices.rank,
420-
// "'batchDimensionCount' must be less than 'indices.rank'.")
421-
// TODO: precondition(batchDimensionCount < rank,
422-
// "'batchDimensionCount' must be less than the tensor's rank.")
417+
precondition(batchDimensionCount >= 0, "'batchDimensionCount' must be non-negative.")
418+
precondition(
419+
batchDimensionCount < indices.rank,
420+
"'batchDimensionCount' must be less than 'indices.rank'.")
421+
withoutDerivative(at: rank) {
422+
precondition(
423+
batchDimensionCount < $0,
424+
"'batchDimensionCount' must be less than the tensor's rank.")
425+
}
423426

424427
// Handle the axis argument by transposing the axis dimension so that it is the first
425428
// non-batch dimension, recursively calling `batchGathering` with `axis = 0`, and then
@@ -472,7 +475,7 @@ public extension Tensor {
472475
let dShape = Tensor<Int32>(concatenating: [
473476
Tensor<Int32>([Int32](repeating: 1, count: d - 1)),
474477
dValue.rankLifted(),
475-
Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
478+
Tensor<Int32>([Int32](repeating: 1, count: indices.rank - d))])
476479
batchIndices += dIndices.reshaped(toShape: dShape)
477480
}
478481
return batchIndices

Tests/TensorFlowTests/OperatorTests/BasicTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ final class BasicOperatorTests: XCTestCase {
2727
[1.0, 2.0, 3.0],
2828
[4.0, 5.0, 6.0]]])
2929
let y = x.batchGathering(
30-
atIndices: Tensor<Int32>([1, 0]),
30+
atIndices: Tensor<Int32>([[[1], [0]]]),
3131
alongAxis: 2,
3232
batchDimensionCount: 2)
33-
XCTAssertEqual(y, Tensor<Float>([2.0, 4.0]))
33+
XCTAssertEqual(y, Tensor<Float>([[[2.0], [4.0]]]))
3434
}
3535

3636
func testPadded() {

0 commit comments

Comments
 (0)