Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1138b08

Browse files
eaplataniosrxwei
authored andcommitted
Cleaned up 'Swift.withoutDerivative(at:in:)' uses. (#285)
1 parent 6045403 commit 1138b08

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

Sources/TensorFlow/Layers/Recurrent.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ public struct RNN<Cell: RNNCell>: Layer {
264264

265265
@differentiable(wrt: (self, inputs))
266266
public func callAsFunction(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
267-
return self(inputs, initialState: Swift.withoutDerivative(at: cell.zeroState, in: { $0 }))
267+
return self(inputs, initialState: withoutDerivative(at: cell.zeroState))
268268
}
269269

270270
/* TODO: Uncomment once control flow and differentiation through force unwrapping is supported.

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,8 @@ public extension Tensor {
414414
func batchGathering<Index: TensorFlowIndex>(atIndices indices: Tensor<Index>) -> Tensor {
415415
var batchIndices = indices
416416
var accumulated = Tensor<Index>(ones: [])
417-
accumulated *= Swift.withoutDerivative(at: shapeTensor) { Tensor<Index>($0[1]) }
418-
let dValue = Swift.withoutDerivative(at: shapeTensor) { $0[0] }
417+
accumulated *= withoutDerivative(at: shapeTensor) { Tensor<Index>($0[1]) }
418+
let dValue = withoutDerivative(at: shapeTensor) { $0[0] }
419419
let dIndices = Tensor<Index>(
420420
rangeFrom: Tensor<Index>(zeros: []),
421421
to: Tensor<Index>(dValue),
@@ -426,8 +426,8 @@ public extension Tensor {
426426
Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
427427
batchIndices += dIndices.reshaped(toShape: dShape)
428428
let flatIndices = batchIndices.flattened()
429-
let outerShape = Swift.withoutDerivative(at: shapeTensor) { $0[2...] }
430-
let innerShape = Swift.withoutDerivative(at: shapeTensor) { $0[..<2] }.product(squeezingAxes: [0])
429+
let outerShape = withoutDerivative(at: shapeTensor) { $0[2...] }
430+
let innerShape = withoutDerivative(at: shapeTensor) { $0[..<2] }.product(squeezingAxes: [0])
431431
let flatTensor = reshaped(toShape: innerShape.rankLifted().concatenated(with: outerShape))
432432
let flatResult = flatTensor.gathering(atIndices: flatIndices)
433433
return flatResult.reshaped(toShape: indices.shapeTensor.concatenated(with: outerShape))
@@ -470,7 +470,7 @@ public extension Tensor {
470470
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
471471
func gathering(where mask: Tensor<Bool>, alongAxis axis: Int = 0) -> Tensor {
472472
precondition(mask.rank != 0, "The boolean mask cannot be a scalar.")
473-
let posAxis = Swift.withoutDerivative(at: self.rank) { r in axis < 0 ? axis + r : axis }
473+
let posAxis = withoutDerivative(at: self.rank) { r in axis < 0 ? axis + r : axis }
474474
let leadingSize = shapeTensor[posAxis ..< posAxis + mask.rank].product().rankLifted()
475475
let reshapedTensor = reshaped(
476476
toShape: Tensor<Int32>(concatenating: [

Sources/TensorFlow/Operators/Math.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,13 +1932,13 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
19321932
@differentiable(wrt: self)
19331933
func logSumExp(squeezingAxes axes: Tensor<Int32>) -> Tensor {
19341934
let rawMax = max(alongAxes: axes)
1935-
let offset = Swift.withoutDerivative(at: rawMax) { rawMax in
1935+
let offset = withoutDerivative(at: rawMax) { rawMax in
19361936
rawMax.replacing(
19371937
with: Tensor<Scalar>(zerosLike: rawMax),
19381938
where: rawMax.isFinite)
19391939
}
19401940
let result = TensorFlow.log(TensorFlow.exp(self - offset).sum(squeezingAxes: axes))
1941-
let resultShape = Swift.withoutDerivative(at: result.shapeTensor, in: identity)
1941+
let resultShape = withoutDerivative(at: result.shapeTensor)
19421942
return result + offset.reshaped(toShape: resultShape)
19431943
}
19441944

@@ -1954,7 +1954,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
19541954
@differentiable(wrt: self)
19551955
func logSumExp(squeezingAxes axes: [Int]) -> Tensor {
19561956
// TODO(TF-433): Remove workaround for differentiating `map`.
1957-
let axes = Swift.withoutDerivative(at: axes) { $0.map(Int32.init) }
1957+
let axes = withoutDerivative(at: axes) { $0.map(Int32.init) }
19581958
return logSumExp(squeezingAxes: Tensor<Int32>(axes))
19591959
}
19601960

@@ -1996,7 +1996,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
19961996
@differentiable(wrt: self)
19971997
func logSumExp(alongAxes axes: Tensor<Int32>) -> Tensor {
19981998
let rawMax = max(alongAxes: axes)
1999-
let offset = Swift.withoutDerivative(at: rawMax) { rawMax in
1999+
let offset = withoutDerivative(at: rawMax) { rawMax in
20002000
rawMax.replacing(
20012001
with: Tensor<Scalar>(zerosLike: rawMax),
20022002
where: rawMax.isFinite)
@@ -2018,7 +2018,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
20182018
@differentiable(wrt: self)
20192019
func logSumExp(alongAxes axes: [Int]) -> Tensor {
20202020
// TODO(TF-433): Remove workaround for differentiating `map`.
2021-
let axes = Swift.withoutDerivative(at: axes) { $0.map(Int32.init) }
2021+
let axes = withoutDerivative(at: axes) { $0.map(Int32.init) }
20222022
return logSumExp(alongAxes: Tensor<Int32>(axes))
20232023
}
20242024

0 commit comments

Comments
 (0)