Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5f9d1fa

Browse files
eaplataniosrxwei
authored andcommitted
Made 'Tensor.gathering(where:alongAxis:)' differentiable. (#271)
1 parent d264701 commit 5f9d1fa

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,12 +464,10 @@ public extension Tensor {
464464
/// - Returns: `(self.rank - K + 1)`-dimensional tensor populated by entries in this tensor
465465
/// corresponding to `true` values in `mask`.
466466
@inlinable
467-
// @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
467+
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
468468
func gathering(where mask: Tensor<Bool>, alongAxis axis: Int = 0) -> Tensor {
469469
precondition(mask.rank != 0, "The boolean mask cannot be a scalar.")
470-
// TODO: Remove once control flow AD is supported.
471-
let rank = self.rank
472-
let posAxis = { axis < 0 ? axis + rank : axis }()
470+
let posAxis = Swift.withoutDerivative(at: self.rank) { r in axis < 0 ? axis + r : axis }
473471
let leadingSize = shapeTensor[posAxis ..< posAxis + mask.rank].product().rankLifted()
474472
let reshapedTensor = reshaped(
475473
toShape: Tensor<Int32>(concatenating: [

0 commit comments

Comments
 (0)