This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Cleaned up 'Swift.withoutDerivative(at:in:)' uses. #285

Merged 5 commits on Jun 24, 2019
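A minimal sketch of the spelling this PR cleans up (the `doubledWithStaticShape` function below is hypothetical, not code from the repository): `withoutDerivative(at:)` returns its argument while cutting off derivative propagation through it, so the closure-taking `Swift.withoutDerivative(at:in:)` form is only needed when the closure does more than return its input.

import TensorFlow

@differentiable
func doubledWithStaticShape(_ x: Tensor<Float>) -> Tensor<Float> {
    // Old spelling: wrap the value in an identity closure.
    //   let shape = Swift.withoutDerivative(at: x.shapeTensor, in: { $0 })
    // New spelling: simply mark the value as non-differentiable.
    let shape = withoutDerivative(at: x.shapeTensor)
    return (x * 2).reshaped(toShape: shape)
}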
Sources/TensorFlow/Layers/Recurrent.swift (2 changes: 1 addition & 1 deletion)
@@ -264,7 +264,7 @@ public struct RNN<Cell: RNNCell>: Layer {

@differentiable(wrt: (self, inputs))
public func callAsFunction(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
-return self(inputs, initialState: Swift.withoutDerivative(at: cell.zeroState, in: { $0 }))
+return self(inputs, initialState: withoutDerivative(at: cell.zeroState))
}

/* TODO: Uncomment once control flow and differentiation through force unwrapping is supported.
Sources/TensorFlow/Operators/Basic.swift (10 changes: 5 additions & 5 deletions)
@@ -414,8 +414,8 @@ public extension Tensor {
func batchGathering<Index: TensorFlowIndex>(atIndices indices: Tensor<Index>) -> Tensor {
var batchIndices = indices
var accumulated = Tensor<Index>(ones: [])
-accumulated *= Swift.withoutDerivative(at: shapeTensor) { Tensor<Index>($0[1]) }
-let dValue = Swift.withoutDerivative(at: shapeTensor) { $0[0] }
+accumulated *= withoutDerivative(at: shapeTensor) { Tensor<Index>($0[1]) }
+let dValue = withoutDerivative(at: shapeTensor) { $0[0] }
let dIndices = Tensor<Index>(
rangeFrom: Tensor<Index>(zeros: []),
to: Tensor<Index>(dValue),
@@ -426,8 +426,8 @@
Tensor<Int32>([Int32](repeating: 1, count: indices.rank - 1))])
batchIndices += dIndices.reshaped(toShape: dShape)
let flatIndices = batchIndices.flattened()
-let outerShape = Swift.withoutDerivative(at: shapeTensor) { $0[2...] }
-let innerShape = Swift.withoutDerivative(at: shapeTensor) { $0[..<2] }.product(squeezingAxes: [0])
+let outerShape = withoutDerivative(at: shapeTensor) { $0[2...] }
+let innerShape = withoutDerivative(at: shapeTensor) { $0[..<2] }.product(squeezingAxes: [0])
let flatTensor = reshaped(toShape: innerShape.rankLifted().concatenated(with: outerShape))
let flatResult = flatTensor.gathering(atIndices: flatIndices)
return flatResult.reshaped(toShape: indices.shapeTensor.concatenated(with: outerShape))
@@ -470,7 +470,7 @@ public extension Tensor {
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func gathering(where mask: Tensor<Bool>, alongAxis axis: Int = 0) -> Tensor {
precondition(mask.rank != 0, "The boolean mask cannot be a scalar.")
-let posAxis = Swift.withoutDerivative(at: self.rank) { r in axis < 0 ? axis + r : axis }
+let posAxis = withoutDerivative(at: self.rank) { r in axis < 0 ? axis + r : axis }
let leadingSize = shapeTensor[posAxis ..< posAxis + mask.rank].product().rankLifted()
let reshapedTensor = reshaped(
toShape: Tensor<Int32>(concatenating: [
Sources/TensorFlow/Operators/Math.swift (10 changes: 5 additions & 5 deletions)
@@ -1932,13 +1932,13 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
@differentiable(wrt: self)
func logSumExp(squeezingAxes axes: Tensor<Int32>) -> Tensor {
let rawMax = max(alongAxes: axes)
-let offset = Swift.withoutDerivative(at: rawMax) { rawMax in
+let offset = withoutDerivative(at: rawMax) { rawMax in
rawMax.replacing(
with: Tensor<Scalar>(zerosLike: rawMax),
where: rawMax.isFinite)
}
let result = TensorFlow.log(TensorFlow.exp(self - offset).sum(squeezingAxes: axes))
-let resultShape = Swift.withoutDerivative(at: result.shapeTensor, in: identity)
+let resultShape = withoutDerivative(at: result.shapeTensor)
return result + offset.reshaped(toShape: resultShape)
}

@@ -1954,7 +1954,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
@differentiable(wrt: self)
func logSumExp(squeezingAxes axes: [Int]) -> Tensor {
// TODO(TF-433): Remove workaround for differentiating `map`.
-let axes = Swift.withoutDerivative(at: axes) { $0.map(Int32.init) }
+let axes = withoutDerivative(at: axes) { $0.map(Int32.init) }
return logSumExp(squeezingAxes: Tensor<Int32>(axes))
}

@@ -1996,7 +1996,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
@differentiable(wrt: self)
func logSumExp(alongAxes axes: Tensor<Int32>) -> Tensor {
let rawMax = max(alongAxes: axes)
-let offset = Swift.withoutDerivative(at: rawMax) { rawMax in
+let offset = withoutDerivative(at: rawMax) { rawMax in
rawMax.replacing(
with: Tensor<Scalar>(zerosLike: rawMax),
where: rawMax.isFinite)
@@ -2018,7 +2018,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
@differentiable(wrt: self)
func logSumExp(alongAxes axes: [Int]) -> Tensor {
// TODO(TF-433): Remove workaround for differentiating `map`.
-let axes = Swift.withoutDerivative(at: axes) { $0.map(Int32.init) }
+let axes = withoutDerivative(at: axes) { $0.map(Int32.init) }
return logSumExp(alongAxes: Tensor<Int32>(axes))
}
