This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add vjps that don't use slicing or gather on shapes. #556

Merged · 1 commit · Nov 15, 2019
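
This change rewrites the pullbacks for mean and matmul so that reduction counts and result shapes are computed on the host from the `[Int]`-backed `TensorShape`, instead of slicing or gathering from `shapeTensor`, which issues extra TensorFlow ops in the backward pass. As a rough illustration (not part of the patch, and assuming the library's `gradient(at:in:)` helper for tensor-valued results), differentiating through a mean reduction exercises one of the rewritten pullbacks:

import TensorFlow

// Illustrative only: the pullback of `mean(alongAxes:)` scales the incoming
// cotangent by 1 / (product of the reduced dimensions), here 1/3.
let x = Tensor<Float>(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6])
let grad = gradient(at: x) { x in x.mean(alongAxes: [1]).sum() }
print(grad)  // Expected: a [2, 3] tensor with every element equal to 1/3.
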
Sources/TensorFlow/Operators/Math.swift: 45 changes (36 additions & 9 deletions)
@@ -1890,7 +1890,7 @@ public extension Tensor where Scalar: Numeric {
/// - Parameter axes: The dimensions to reduce.
/// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
@inlinable
- @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+ @differentiable(wrt: self, vjp: _vjpMean(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
func mean(squeezingAxes axes: [Int]) -> Tensor {
// TODO(TF-433): Remove workaround for differentiating `map`.
let axes = {axes.map(Int32.init)}()
@@ -1927,7 +1927,7 @@ public extension Tensor where Scalar: Numeric {
/// - Parameter axes: The dimensions to reduce.
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable
- @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+ @differentiable(wrt: self, vjp: _vjpMean(alongAxes:) where Scalar: TensorFlowFloatingPoint)
func mean(alongAxes axes: [Int]) -> Tensor {
// TODO(TF-433): Remove workaround for differentiating `map`.
let axes = {axes.map(Int32.init)}()
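
The two hunks above only touch the `@differentiable` attribute: its `vjp:` argument, supported by the Swift for TensorFlow toolchain of this period, points the compiler at a hand-written function that returns the original value together with a pullback closure. A minimal sketch of that registration pattern (hypothetical function, not from this patch):

// Sketch of the `vjp:` registration pattern: the attribute names a function
// returning (value, pullback), bypassing compiler-derived differentiation.
@differentiable(vjp: _vjpTriple)
func triple(_ x: Tensor<Float>) -> Tensor<Float> {
    x * 3
}

func _vjpTriple(_ x: Tensor<Float>) -> (Tensor<Float>, (Tensor<Float>) -> Tensor<Float>) {
    (triple(x), { v in v * 3 })
}
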
@@ -2201,6 +2201,31 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
})
}

+ // Specialization to avoid _Raw.gather on shapes when axes is known to be
+ // [Int].
+ @inlinable
+ func _vjpMean(alongAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
+ let value = mean(alongAxes: axes)
+ // Cache shape because it is a computed property.
+ let cachedShape = shape
+ let count = axes.map { cachedShape[$0] }.reduce(1, *)
+ return (value, { [shape = shapeTensor] in $0.broadcasted(toShape: shape) / Tensor(Scalar(count)) })
+ }
+
+ // Specialization to avoid _Raw.gather on shapes when axes is known to be
+ // [Int].
+ @inlinable
+ func _vjpMean(squeezingAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
+ let value = mean(squeezingAxes: axes)
+ // Cache shape because it is a computed property.
+ let cachedShape = shape
+ let count = axes.map { cachedShape[$0] }.reduce(1, *)
+ return (value, { [shape = shapeTensor] v in
+ let unsqueezed = v.expandingShape(at: axes)
+ return unsqueezed.broadcasted(toShape: shape) / Tensor(Scalar(count))
+ })
+ }
+
@inlinable
func _vjpCumulativeSum(
alongAxis axis: Tensor<Int32>,
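
The hunk above adds `[Int]`-specialized pullbacks for `mean`: the number of reduced elements is computed from the host-side `TensorShape` (cached once, since `shape` is a computed property), so the backward pass no longer gathers from `shapeTensor`. A hand-worked instance of the same arithmetic, with made-up values:

// Hand-worked instance (hypothetical values) of what `_vjpMean(alongAxes:)`
// computes: broadcast the cotangent back to the input shape, then divide by
// the number of reduced elements.
let x = Tensor<Float>(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6])
let axes = [1]

let cachedShape = x.shape                              // TensorShape [2, 3], no TF op involved.
let count = axes.map { cachedShape[$0] }.reduce(1, *)  // 3

// Pull back a cotangent of ones with the reduced shape [2, 1].
let seed = Tensor<Float>(ones: [2, 1])
let pulledBack = seed.broadcasted(toShape: x.shapeTensor) / Tensor(Float(count))
// `pulledBack` has shape [2, 3] and every element equals 1/3.
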
@@ -2606,7 +2631,7 @@ internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
transposed transposeRhs: Bool = false
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
let value = matmul(lhs, transposed: transposeLhs, rhs, transposed: transposeRhs)
- return (value, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+ return (value, { [lhsShape = lhs.shape, rhsShape = rhs.shape] v in
let (lhsGrad, rhsGrad): (Tensor<Scalar>, Tensor<Scalar>)
switch (transposeLhs, transposeRhs) {
case (false, false):
@@ -2622,13 +2647,15 @@ internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
lhsGrad = matmul(v, transposed: true, rhs, transposed: true)
rhsGrad = matmul(lhs, transposed: true, v, transposed: true)
}
- let lhsRank = lhsShape.shape[0] - 2
- let rhsRank = rhsShape.shape[0] - 2
+ let lhsRank = lhsShape.rank - 2
+ let rhsRank = rhsShape.rank - 2
let (lhsAxes, rhsAxes) = _Raw.broadcastGradientArgs(
- s0: lhsShape[..<lhsRank],
- s1: rhsShape[..<rhsRank])
- return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
- rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
+ s0: Tensor<Int32>(lhsShape.dimensions[..<lhsRank].map { Int32($0) } ),
+ s1: Tensor<Int32>(rhsShape.dimensions[..<rhsRank].map { Int32($0) } ))
+ let lhsShapeTensor = Tensor<Int32>(lhsShape.dimensions.map { Int32($0) })
+ let rhsShapeTensor = Tensor<Int32>(rhsShape.dimensions.map { Int32($0) })
+ return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShapeTensor),
+ rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShapeTensor))
})
}
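
In the matmul pullback above, the closure now captures host-side `TensorShape` values rather than `shapeTensor`, and the batch-dimension prefixes passed to `_Raw.broadcastGradientArgs` are built from `TensorShape.dimensions`, so no slicing of a shape tensor is traced into the backward pass. A worked example of that shape arithmetic for a hypothetical broadcast batch:

// Hypothetical shapes: lhs is a batch of 4 matrices, rhs a batch of 1 that
// broadcasts against it, so rhs's gradient must be summed over the broadcast
// batch axis before being reshaped back.
let lhsShape = TensorShape([4, 2, 3])
let rhsShape = TensorShape([1, 3, 2])
let lhsRank = lhsShape.rank - 2   // 1 leading batch dimension
let rhsRank = rhsShape.rank - 2   // 1 leading batch dimension

// Batch prefixes are built on the host, as in the new pullback body.
let (lhsAxes, rhsAxes) = _Raw.broadcastGradientArgs(
    s0: Tensor<Int32>(lhsShape.dimensions[..<lhsRank].map { Int32($0) }),
    s1: Tensor<Int32>(rhsShape.dimensions[..<rhsRank].map { Int32($0) }))
print(lhsAxes, rhsAxes)  // Expected: [] (nothing to reduce) and [0] (the broadcast axis).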
