
Commit dbe7835 (parent: 1702f52)

Add vjps that don't use slicing or gather on shapes.

1 file changed: +36, −9 lines
Sources/TensorFlow/Operators/Math.swift

@@ -1890,7 +1890,7 @@ public extension Tensor where Scalar: Numeric {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
     @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpMean(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
     func mean(squeezingAxes axes: [Int]) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
         let axes = {axes.map(Int32.init)}()
@@ -1927,7 +1927,7 @@ public extension Tensor where Scalar: Numeric {
     /// - Parameter axes: The dimensions to reduce.
     /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
     @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpMean(alongAxes:) where Scalar: TensorFlowFloatingPoint)
     func mean(alongAxes axes: [Int]) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
         let axes = {axes.map(Int32.init)}()
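
Both `mean` overloads now point the attribute's `vjp:` argument at a hand-written pullback, so differentiation of the `[Int]`-axes entry points routes through the specializations added further down instead of the generic reduction derivative. A toy sketch of that registration pattern (hypothetical `doubled`/`_vjpDoubled` names, same pre-SE-0258 attribute spelling; not code from this file):

    // Registering a hand-written VJP: derivatives of `doubled` are computed
    // by `_vjpDoubled` rather than by a compiler-derived pullback.
    @differentiable(vjp: _vjpDoubled)
    func doubled(_ x: Tensor<Float>) -> Tensor<Float> {
        return x * 2
    }

    // The VJP returns the result together with a pullback that maps an
    // incoming gradient to the gradient with respect to the input.
    func _vjpDoubled(_ x: Tensor<Float>) -> (Tensor<Float>, (Tensor<Float>) -> Tensor<Float>) {
        return (doubled(x), { v in v * 2 })
    }
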
@@ -2201,6 +2201,31 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
         })
     }
 
+    // Specialization to avoid _Raw.gather on shapes when axes is known to be
+    // [Int].
+    @inlinable
+    func _vjpMean(alongAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
+        let value = mean(alongAxes: axes)
+        // Cache shape because it is a computed property.
+        let cachedShape = shape
+        let count = axes.map { cachedShape[$0] }.reduce(1, *)
+        return (value, { [shape = shapeTensor] in $0.broadcasted(toShape: shape) / Tensor(Scalar(count)) })
+    }
+
+    // Specialization to avoid _Raw.gather on shapes when axes is known to be
+    // [Int].
+    @inlinable
+    func _vjpMean(squeezingAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
+        let value = mean(squeezingAxes: axes)
+        // Cache shape because it is a computed property.
+        let cachedShape = shape
+        let count = axes.map { cachedShape[$0] }.reduce(1, *)
+        return (value, { [shape = shapeTensor] v in
+            let unsqueezed = v.expandingShape(at: axes)
+            return unsqueezed.broadcasted(toShape: shape) / Tensor(Scalar(count))
+        })
+    }
+
     @inlinable
     func _vjpCumulativeSum(
         alongAxis axis: Tensor<Int32>,
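
The new pullbacks broadcast the incoming gradient back to the input shape and divide by the number of reduced elements, reading dimension sizes from the host-side `shape` rather than gathering them from a shape tensor. A minimal sketch that exercises the `[Int]`-axes path, assuming the standard `gradient(at:in:)` entry point:

    import TensorFlow

    // Averaging 3 elements along axis 1, then summing: each input element
    // should receive a gradient of 1/3.
    let x = Tensor<Float>(ones: [2, 3])
    let dx = gradient(at: x) { $0.mean(squeezingAxes: [1]).sum() }
    print(dx)  // expected: 0.333... everywhere, shape [2, 3]
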
@@ -2606,7 +2631,7 @@ internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
     transposed transposeRhs: Bool = false
 ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
     let value = matmul(lhs, transposed: transposeLhs, rhs, transposed: transposeRhs)
-    return (value, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+    return (value, { [lhsShape = lhs.shape, rhsShape = rhs.shape] v in
         let (lhsGrad, rhsGrad): (Tensor<Scalar>, Tensor<Scalar>)
         switch (transposeLhs, transposeRhs) {
         case (false, false):
@@ -2622,13 +2647,15 @@ internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
             lhsGrad = matmul(v, transposed: true, rhs, transposed: true)
             rhsGrad = matmul(lhs, transposed: true, v, transposed: true)
         }
-        let lhsRank = lhsShape.shape[0] - 2
-        let rhsRank = rhsShape.shape[0] - 2
+        let lhsRank = lhsShape.rank - 2
+        let rhsRank = rhsShape.rank - 2
         let (lhsAxes, rhsAxes) = _Raw.broadcastGradientArgs(
-            s0: lhsShape[..<lhsRank],
-            s1: rhsShape[..<rhsRank])
-        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
-                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
+            s0: Tensor<Int32>(lhsShape.dimensions[..<lhsRank].map { Int32($0) }),
+            s1: Tensor<Int32>(rhsShape.dimensions[..<rhsRank].map { Int32($0) }))
+        let lhsShapeTensor = Tensor<Int32>(lhsShape.dimensions.map { Int32($0) })
+        let rhsShapeTensor = Tensor<Int32>(rhsShape.dimensions.map { Int32($0) })
+        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShapeTensor),
+                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShapeTensor))
     })
 }
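
The matmul pullback now captures host-side `TensorShape` values instead of shape tensors, so computing the broadcast axes and the final reshape no longer slices a shape tensor on device. A minimal sketch exercising that pullback, assuming the standard two-argument `gradient(at:_:in:)` entry point:

    import TensorFlow

    // With z = matmul(a, b).sum(), the pullback gives da = ones(2x4) · bᵀ and
    // db = aᵀ · ones(2x4). For all-ones inputs that is 4 everywhere in da and
    // 2 everywhere in db.
    let a = Tensor<Float>(ones: [2, 3])
    let b = Tensor<Float>(ones: [3, 4])
    let (da, db) = gradient(at: a, b) { a, b in matmul(a, b).sum() }
    print(da)  // expected: 4.0 everywhere, shape [2, 3]
    print(db)  // expected: 2.0 everywhere, shape [3, 4]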
