Skip to content

Commit 9e089ac

Browse files
sgugger
authored and rxwei committed
[TF] Fix gradients in sum(squeezingAxes:) and mean(squeezingAxes:) (#24164)
* [TF] Fix gradients in sum(squeezingAxes:) and mean(squeezingAxes:). sum(squeezingAxes:) and mean(squeezingAxes:) were throwing an error during the backward pass because the gradients weren't unsqueezed before being broadcast. Note that this could be refactored nicely if we had a function that took a list of ints for `expandingShape`. Second note: I may be wrong, but it seems like `_vjpMean(squeezingAxes axes: [Int])` is never used and only the Tensor<Int32> version is. * Remove unused `_vjpMean` function. * Update Gradients.swift * Add test * Minor edit for consistency.
1 parent bb0b90e commit 9e089ac

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,11 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
579579
squeezingAxes axes: Tensor<Int32>
580580
) -> (Tensor, (Tensor) -> Tensor) {
581581
let value = sum(squeezingAxes: axes)
582-
return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
582+
return (value, { [shape = shapeTensor] in
583+
var res = $0
584+
for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) }
585+
return res.broadcast(toShape: shape)
586+
})
583587
}
584588

585589
@inlinable
@@ -591,23 +595,16 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
591595
})
592596
}
593597

594-
@inlinable
595-
func _vjpMean(squeezingAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
596-
let value = mean(squeezingAxes: axes)
597-
return (value, { [shape = shapeTensor,
598-
count = axes.map { shape[$0] }.reduce(1, *)] in
599-
$0.broadcast(toShape: shape) / Tensor(Scalar(count))
600-
})
601-
}
602-
603598
@inlinable
604599
func _vjpMean(
605600
squeezingAxes axes: Tensor<Int32>
606601
) -> (Tensor, (Tensor) -> Tensor) {
607602
let value = mean(squeezingAxes: axes)
608603
let count = Raw.gather(params: shapeTensor, indices: axes).product()
609604
return (value, { [shape = shapeTensor] in
610-
$0.broadcast(toShape: shape) / Tensor(count)
605+
var res = $0
606+
for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) }
607+
return res.broadcast(toShape: shape) / Tensor(count)
611608
})
612609
}
613610
}

test/TensorFlowRuntime/tensor_autodiff_runtime.swift

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,38 +98,39 @@ TensorADTests.testAllBackends("Abs") {
9898
TensorADTests.testAllBackends("sum") {
9999
let input = Tensor<Float>(repeating: 42, shape: [2, 2])
100100
let sumPullbackScalar = pullback(at: input) { (a: Tensor<Float>) in a.sum() }
101+
let sumPullbackSqueezingAxes = pullback(at: input) { (a: Tensor<Float>) in a.sum(squeezingAxes: 0, 1) }
101102
let sumPullbackAlongAxes = pullback(at: input) { (a: Tensor<Float>) in a.sum(alongAxes: 0, 1) }
102103

103104
let expected = Tensor<Float>(ones: [2, 2])
104105
expectEqual(expected, sumPullbackScalar(Tensor(1)))
105-
// expectEqual(expected, sumPullbackSqueezingAxes(Tensor(1)))
106+
expectEqual(expected, sumPullbackSqueezingAxes(Tensor(1)))
106107
expectEqual(expected, sumPullbackAlongAxes(Tensor(1)))
107108
expectEqual(expected * 3, sumPullbackScalar(Tensor(3)))
108-
// expectEqual(expected * 3, sumPullbackSqueezingAxes(Tensor(3)))
109+
expectEqual(expected * 3, sumPullbackSqueezingAxes(Tensor(3)))
109110
expectEqual(expected * 3, sumPullbackAlongAxes(Tensor(3)))
110111
}
111112

112113
TensorADTests.testAllBackends("mean") {
113114
let meanGradScalar = gradient { (a: Tensor<Float>) in a.mean() }
114-
// let meanGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.mean(squeezingAxes: 0, 1) }
115+
let meanGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.mean(squeezingAxes: 0, 1) }
115116
let meanGradAlongAxes = gradient { (a: Tensor<Float>) in a.mean(alongAxes: 0, 1) }
116117

117118
let input = Tensor<Float>(ones: [2, 2])
118119
let expected = Tensor<Float>(repeating: 0.25, shape: [2, 2])
119120
expectEqual(expected, meanGradScalar(input))
120-
// expectEqual(expected, meanGradSqueezingAxes(input))
121+
expectEqual(expected, meanGradSqueezingAxes(input))
121122
expectEqual(expected, meanGradAlongAxes(input))
122123
}
123124

124125
TensorADTests.testAllBackends("variance") {
125126
let varianceGradScalar = gradient { (a: Tensor<Float>) in a.variance() }
126-
// let varianceGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.variance(squeezingAxes: 0, 1) }
127+
let varianceGradSqueezingAxes = gradient { (a: Tensor<Float>) in a.variance(squeezingAxes: 0, 1) }
127128
let varianceGradAlongAxes = gradient { (a: Tensor<Float>) in a.variance(alongAxes: 0, 1) }
128129

129130
let input: Tensor<Float> = [[1, 2], [3, 4]]
130131
let expected: Tensor<Float> = [[-0.75, -0.25], [0.25, 0.75]]
131132
expectEqual(expected, varianceGradScalar(input))
132-
// expectEqual(expected, varianceGradSqueezingAxes(input))
133+
expectEqual(expected, varianceGradSqueezingAxes(input))
133134
expectEqual(expected, varianceGradAlongAxes(input))
134135
}
135136

0 commit comments

Comments
 (0)