
Commit f194110

eaplatanios authored and rxwei committed
Improved derivative performance for broadcasted operations. (tensorflow#142)
Re-implementation of swiftlang/swift#24408. In the pullback for operators that broadcast, use `Raw.broadcastGradientArgs(s0:s1:)` to compute reduction indices instead of using the inefficient `unbroadcast(to:)`.
1 parent 89b8548 commit f194110
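
The change replaces per-operand `unbroadcast(to:)` calls in binary-op pullbacks with a single `Raw.broadcastGradientArgs(s0:s1:)` call that yields the reduction axes for both operands at once. Below is a minimal sketch of the resulting pullback pattern, assuming this version's `Raw` bindings and `Tensor` API; the helper name `unbroadcastGradients` is hypothetical and not part of the commit.

    import TensorFlow

    // Hypothetical helper illustrating the pattern used throughout this commit:
    // given the incoming gradients for two broadcast operands, compute the axes
    // to reduce over in one shot, then sum along those axes and reshape back to
    // each operand's original shape.
    @inlinable
    func unbroadcastGradients<T: TensorFlowFloatingPoint>(
        _ lhsGrad: Tensor<T>, _ rhsGrad: Tensor<T>,
        lhsShape: Tensor<Int32>, rhsShape: Tensor<Int32>
    ) -> (Tensor<T>, Tensor<T>) {
        let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
    }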

File tree

4 files changed: +62 -83 lines changed

Sources/TensorFlow/Core/DifferentialOperators.swift
Sources/TensorFlow/Core/Tensor.swift
Sources/TensorFlow/Operators/Math.swift
Tests/TensorFlowTests/OperatorTests/MathTests.swift

Sources/TensorFlow/Core/DifferentialOperators.swift

Lines changed: 10 additions & 5 deletions
@@ -21,14 +21,15 @@ public extension Differentiable {
     func gradient<R: TensorFlowFloatingPoint>(
         in f: @differentiable (Self) -> Tensor<R>
     ) -> TangentVector {
-        return self.pullback(in: f)(Tensor<R>(1))
+        return self.valueWithGradient(in: f).1
     }

     @inlinable
     func valueWithGradient<R: TensorFlowFloatingPoint>(
         in f: @differentiable (Self) -> Tensor<R>
     ) -> (value: Tensor<R>, gradient: TangentVector) {
         let (y, pb) = self.valueWithPullback(in: f)
+        precondition(y.rank == 0)
         return (y, pb(Tensor<R>(1)))
     }

@@ -37,7 +38,7 @@ public extension Differentiable {
         at x: T,
         in f: @differentiable (Self, T) -> Tensor<R>
     ) -> (TangentVector, T.TangentVector) {
-        return self.pullback(at: x, in: f)(Tensor<R>(1))
+        return self.valueWithGradient(at: x, in: f).1
     }

     @inlinable
@@ -46,6 +47,7 @@ public extension Differentiable {
         in f: @differentiable (Self, T) -> Tensor<R>
     ) -> (value: Tensor<R>, gradient: (TangentVector, T.TangentVector)) {
         let (y, pb) = self.valueWithPullback(at: x, in: f)
+        precondition(y.rank == 0)
         return (y, pb(Tensor<R>(1)))
     }
 }
@@ -63,6 +65,7 @@ public func valueWithGradient<T, R>(
 ) -> (value: Tensor<R>, gradient: T.TangentVector)
     where T: Differentiable, R: TensorFlowFloatingPoint {
     let (y, pullback) = valueWithPullback(at: x, in: f)
+    precondition(y.rank == 0)
     return (y, pullback(Tensor<R>(1)))
 }

@@ -74,6 +77,7 @@ public func valueWithGradient<T, U, R>(
 ) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector))
     where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
     let (y, pullback) = valueWithPullback(at: x, y, in: f)
+    precondition(y.rank == 0)
     return (y, pullback(Tensor<R>(1)))
 }

@@ -86,6 +90,7 @@ public func valueWithGradient<T, U, R>(
 // ) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
 //     where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
 //     let (y, pullback) = valueWithPullback(at: x, y, z, in: f)
+//     precondition(y.rank == 0)
 //     return (y, pullback(Tensor<R>(1)))
 // }

@@ -124,7 +129,7 @@ public func gradient<T, R>(
     at x: T,
     in f: @differentiable (T) -> Tensor<R>
 ) -> T.TangentVector where T: Differentiable, R: TensorFlowFloatingPoint {
-    return pullback(at: x, in: f)(Tensor<R>(1))
+    return valueWithGradient(at: x, in: f).1
 }

 @inlinable
@@ -134,7 +139,7 @@ public func gradient<T, U, R>(
     in f: @differentiable (T, U) -> Tensor<R>
 ) -> (T.TangentVector, U.TangentVector)
     where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
-    return pullback(at: x, y, in: f)(Tensor<R>(1))
+    return valueWithGradient(at: x, y, in: f).1
 }

 // @inlinable
@@ -145,7 +150,7 @@ public func gradient<T, U, R>(
 //     in f: @differentiable (T, U, V) -> Tensor<R>
 // ) -> (T.TangentVector, U.TangentVector, V.TangentVector)
 //     where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
-//     return pullback(at: x, y, z, in: f)(Tensor<R>(1))
+//     return valueWithGradient(at: x, y, z, in: f).1
 // }

 // Gradient (curried)
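
The `precondition(y.rank == 0)` calls above make the scalar-output requirement of these gradient entry points explicit. A hedged usage sketch, not part of the diff:

    import TensorFlow

    // `gradient(at:in:)` now traps if the differentiated function returns a
    // non-scalar tensor, so reduce to rank 0 (e.g. with `sum()`) before returning.
    let w = Tensor<Float>([[1, 2], [3, 4]])
    let dw = gradient(at: w) { w in (w * w).sum() }  // d/dw of sum(w^2) = 2w
    print(dw)  // [[2.0, 4.0], [6.0, 8.0]]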

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 12 additions & 28 deletions
@@ -526,20 +526,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
         lhs: Tensor,
         rhs: Tensor
     ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs + rhs, { [
-            lhsShape = lhs.shape,
-            rhsShape = rhs.shape,
-            lhsShapeTensor = lhs.shapeTensor,
-            rhsShapeTensor = rhs.shapeTensor] v in
-            var lhsGrad = v
-            var rhsGrad = v
-            if lhsGrad.shape != lhsShape {
-                lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-            }
-            if rhsGrad.shape != rhsShape {
-                rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-            }
-            return (lhsGrad, rhsGrad)
+        return (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+            let lhsGrad = v
+            let rhsGrad = lhsGrad
+            let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+            return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                    rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
         })
     }

@@ -548,20 +540,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
         lhs: Tensor,
         rhs: Tensor
     ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs - rhs, { [
-            lhsShape = lhs.shape,
-            rhsShape = rhs.shape,
-            lhsShapeTensor = lhs.shapeTensor,
-            rhsShapeTensor = rhs.shapeTensor] v in
-            var lhsGrad = v
-            var rhsGrad = -v
-            if lhsGrad.shape != lhsShape {
-                lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-            }
-            if rhsGrad.shape != rhsShape {
-                rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-            }
-            return (lhsGrad, rhsGrad)
+        return (lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+            let lhsGrad = v
+            let rhsGrad = -lhsGrad
+            let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+            return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                    rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
         })
     }
 }
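
For intuition, here is a hedged illustration (not from the commit) of what `Raw.broadcastGradientArgs` computes, assuming the raw-op binding in this version of the library: for each operand it returns the axes along which the incoming gradient must be summed to undo broadcasting.

    import TensorFlow

    // For s0 = [2, 3] and s1 = [3], broadcasting the second operand up to
    // [2, 3] means its gradient must be summed over axis 0; the first operand
    // needs no reduction.
    let s0 = Tensor<Int32>([2, 3])
    let s1 = Tensor<Int32>([3])
    let (r0, r1) = Raw.broadcastGradientArgs(s0: s0, s1: s1)
    print(r0)  // []  — no reduction for the [2, 3] operand
    print(r1)  // [0] — sum over axis 0 to recover shape [3]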

Sources/TensorFlow/Operators/Math.swift

Lines changed: 27 additions & 49 deletions
@@ -43,20 +43,12 @@ extension Tensor: VectorNumeric where Scalar: Numeric {
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     static func _vjpMultiply(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs * rhs, { [
-            lhsShape = lhs.shape,
-            rhsShape = rhs.shape,
-            lhsShapeTensor = lhs.shapeTensor,
-            rhsShapeTensor = rhs.shapeTensor] v in
-            var lhsGrad = rhs * v
-            var rhsGrad = lhs * v
-            if lhsGrad.shape != lhsShape {
-                lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-            }
-            if rhsGrad.shape != rhsShape {
-                rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-            }
-            return (lhsGrad, rhsGrad)
+        return (lhs * rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+            let lhsGrad = rhs * v
+            let rhsGrad = lhs * v
+            let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+            return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                    rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
         })
     }
 }
@@ -236,12 +228,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {

     @inlinable
     static func _vjpSubtract(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
-        return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) })
+        return (lhs - rhs, { v in (v, -v.sum().scalarized()) })
     }

     @inlinable
     static func _vjpSubtract(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) {
-        return (lhs - rhs, { v in (v.sum().scalarized(), 0 - v) })
+        return (lhs - rhs, { v in (v.sum().scalarized(), -v) })
     }

     @inlinable
@@ -256,27 +248,19 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {

     @inlinable
     static func _vjpDivide(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs / rhs, { [
-            lhsShape = lhs.shape,
-            rhsShape = rhs.shape,
-            lhsShapeTensor = lhs.shapeTensor,
-            rhsShapeTensor = rhs.shapeTensor] v in
-            var lhsGrad = v / rhs
-            var rhsGrad = (-lhs) / rhs.squared() * v
-            if lhsGrad.shape != lhsShape {
-                lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-            }
-            if rhsGrad.shape != rhsShape {
-                rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-            }
-            return (lhsGrad, rhsGrad)
+        return (lhs / rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+            let lhsGrad = v / rhs
+            let rhsGrad = -lhs / rhs.squared() * v
+            let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+            return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                    rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
         })
     }

     @inlinable
     static func _vjpDivide(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
         return (lhs / rhs, { v in
-            (v / rhs, (v * (0 - lhs) / Tensor(rhs).squared()).sum().scalarized())
+            (v / rhs, (v * -lhs / Tensor(rhs).squared()).sum().scalarized())
         })
     }

@@ -704,15 +688,12 @@ internal func _vjpPow<T: TensorFlowFloatingPoint>(
     let value = pow(x, y)
     return (value, { v in
         let safeX = x.replacing(with: Tensor<T>(onesLike: x), where: x .<= 0)
-        var gradX = v * y * pow(x, y - 1)
-        var gradY = value * v * log(safeX)
-        if gradX.shape != x.shape {
-            gradX = gradX.unbroadcasted(like: x)
-        }
-        if gradY.shape != y.shape {
-            gradY = gradY.unbroadcasted(like: y)
-        }
-        return (gradX, gradY)
+        let lhsGrad = v * y * pow(x, y - 1)
+        let rhsGrad = value * v * log(safeX)
+        let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
+        let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
     })
 }

@@ -798,15 +779,12 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
     seed: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
     let denominator = 1 + Tensor<T>(x .== y)
-    var gradX = seed * Tensor<T>(x .== originalValue) / denominator
-    var gradY = seed * Tensor<T>(y .== originalValue) / denominator
-    if gradX.shape != x.shape {
-        gradX = gradX.unbroadcasted(like: x)
-    }
-    if gradY.shape != y.shape {
-        gradY = gradY.unbroadcasted(like: y)
-    }
-    return (gradX, gradY)
+    let lhsGrad = seed * Tensor<T>(x .== originalValue) / denominator
+    let rhsGrad = seed * Tensor<T>(y .== originalValue) / denominator
+    let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
+    let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+    return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+            rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
 }

 //===------------------------------------------------------------------------------------------===//
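
A hedged end-to-end check (not part of the commit) that the rewritten multiply pullback reduces broadcast gradients back to each operand's shape:

    import TensorFlow

    // For z = (x * y).sum() with x: [2, 3] and y: [3], dz/dy must accumulate
    // over x's leading axis, which the new pullback does via
    // broadcastGradientArgs + sum(squeezingAxes:) + reshaped(toShape:).
    let x = Tensor<Float>(ones: [2, 3])
    let y = Tensor<Float>([1, 2, 3])
    let (dx, dy) = gradient(at: x, y) { x, y in (x * y).sum() }
    print(dx.shape)  // [2, 3] — each entry is the broadcast value of y
    print(dy.shape)  // [3]   — gradients summed over the broadcast axis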

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 13 additions & 1 deletion
@@ -199,6 +199,17 @@ final class MathOperatorTests: XCTestCase {
         XCTAssertEqual(0.816997, Double(prediction.scalars[0]), accuracy: 0.0001)
     }

+    func testBroadcastedAddGradient() {
+        func foo(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
+            return (x + y).sum()
+        }
+        let x = Tensor<Float>(ones: [1, 2, 1, 4])
+        let y = Tensor<Float>(ones: [4, 1, 3, 1])
+        let (dx, dy) = gradient(at: x, y, in: foo)
+        XCTAssertEqual(x.shape, dx.shape)
+        XCTAssertEqual(y.shape, dy.shape)
+    }
+
     static var allTests = [
         ("testReduction", testReduction),
         ("testArgmax", testArgmax),
@@ -209,6 +220,7 @@ final class MathOperatorTests: XCTestCase {
         ("testMultiOpMath", testMultiOpMath),
         ("testXWPlusB", testXWPlusB),
         ("testXORInference", testXORInference),
-        ("testMLPClassifierStruct", testMLPClassifierStruct)
+        ("testMLPClassifierStruct", testMLPClassifierStruct),
+        ("testBroadcastedAddGradient", testBroadcastedAddGradient)
     ]
 }
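
The new test only checks shapes; a hypothetical extension (not in the commit) could also assert the accumulated values. With x: [1, 2, 1, 4] and y: [4, 1, 3, 1], the broadcast result has shape [4, 2, 3, 4], so every element of dx accumulates 4 * 3 = 12 upstream ones and every element of dy accumulates 2 * 4 = 8:

    // Hypothetical value assertions for the body of testBroadcastedAddGradient.
    XCTAssertEqual(dx, Tensor<Float>(repeating: 12, shape: [1, 2, 1, 4]))
    XCTAssertEqual(dy, Tensor<Float>(repeating: 8, shape: [4, 1, 3, 1]))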
