
Commit 528fb67

[TF] Reimplement unbroadcast using on-host axis calculation for performance. (#24907)
The inefficiency of `unbroadcast(toShape:)`, `unbroadcast(to:)`, and `unbroadcast(like:)` has caused significant performance problems during model training, because they perform many TensorFlow operations just to compute the reduction axes. We were forced to implement them this way in the early GPE era, when neither send/receive nor per-op dispatch was available. This PR reimplements the unbroadcast operations in terms of host-side logic that computes the axes to reduce along, which significantly reduces TensorFlow operation dispatch overhead. The base implementation changed from `unbroadcast(toShape:)` to `unbroadcast(to:)`. With the new implementation, differentiating broadcasting operators is 37% faster (see the simple test script [here](https://gist.github.com/rxwei/e1488cac5379ba2bc3aff7490e18158f)).

Note:
- Since we now rely less on the TensorFlow runtime, more precondition checks and assertions have been added to the newly implemented `unbroadcast(to:)` method.
- The part of #24408 that uses `Raw.broadcastGradientArgs(s0:s1:)` is still necessary for broadcasting binary operations to become faster.

TODO:
- Change the `unbroadcast(toShape:)` tests added in #24899 to use `unbroadcast(to:)`, since `unbroadcast(to:)` is now the base implementation.
1 parent 9e20d2e commit 528fb67
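
To make the on-host axis calculation concrete, here is a minimal standalone sketch of the logic using plain Swift arrays rather than the library's `Tensor`/`TensorShape` types; the function name and signature are illustrative only, not part of this commit.

// Sketch: compute the axes a tensor of shape `dimensions` must be summed
// along so that it can be reshaped to `targetDimensions` (the inverse of
// broadcasting). Plain Swift, runnable without the TensorFlow runtime.
func unbroadcastAxes(from dimensions: [Int], to targetDimensions: [Int]) -> [Int] {
    var padded = targetDimensions
    let rankDifference = dimensions.count - padded.count
    precondition(rankDifference >= 0,
                 "The source rank must be at least the destination rank")
    // Left-pad the destination shape with 1s so both shapes have equal rank.
    if rankDifference > 0 {
        padded.insert(contentsOf: repeatElement(1, count: rankDifference), at: 0)
    }
    var axes: [Int] = []
    for (i, (dim, targetDim)) in zip(dimensions, padded).enumerated() {
        if dim == targetDim { continue }
        if targetDim == 1 { axes.append(i); continue }
        preconditionFailure("Cannot unbroadcast \(dimensions) to \(targetDimensions)")
    }
    return axes
}

// Example: a [2, 3, 4] tensor unbroadcast to [3, 1] reduces along axes 0 and 2.
print(unbroadcastAxes(from: [2, 3, 4], to: [3, 1]))  // [0, 2]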

2 files changed: 37 additions & 21 deletions

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 3 additions & 5 deletions
@@ -645,16 +645,14 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
   func _vjpBroadcast(
     toShape shape: Tensor<Int32>
   ) -> (Tensor, (Tensor) -> Tensor) {
-    return (broadcast(toShape: shape), { [origShape = self.shapeTensor] v in
+    return (broadcast(toShape: shape), { [origShape = shapeTensor] v in
       v.unbroadcast(toShape: origShape)
     })
   }

   @inlinable
-  func _vjpUnbroadcast(
-    toShape shape: Tensor<Int32>
-  ) -> (Tensor, (Tensor) -> Tensor) {
-    return (unbroadcast(toShape: shape), { [origShape = self.shapeTensor] v in
+  func _vjpUnbroadcast(to shape: TensorShape) -> (Tensor, (Tensor) -> Tensor) {
+    return (unbroadcast(to: shape), { [origShape = shapeTensor] v in
       v.broadcast(toShape: origShape)
     })
   }
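
For intuition on why these two pullbacks mirror each other: broadcasting copies a value along new or size-1 axes, and the adjoint of a copy is a sum over the copies, which is exactly what unbroadcast performs. A minimal sketch with plain Swift values (illustrative names, not the library API):

// Broadcasting a scalar to `count` copies; the pullback of that copy
// operation sums the incoming cotangents back into a single scalar.
func broadcastScalar(_ x: Double, count: Int) -> [Double] {
    Array(repeating: x, count: count)
}

func broadcastScalarPullback(_ cotangent: [Double]) -> Double {
    cotangent.reduce(0, +)  // adjoint of "copy n times" is "sum n values"
}

let y = broadcastScalar(2.0, count: 3)              // [2.0, 2.0, 2.0]
let dx = broadcastScalarPullback([1.0, 1.0, 1.0])   // 3.0
print(y, dx)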

stdlib/public/TensorFlow/Ops.swift

Lines changed: 34 additions & 16 deletions
@@ -1601,40 +1601,36 @@ public extension Tensor {
 public extension Tensor {
   @inlinable
   @differentiable(wrt: self, vjp: _vjpBroadcast(toShape:)
-    where Scalar : TensorFlowFloatingPoint)
+                  where Scalar : TensorFlowFloatingPoint)
   func broadcast(toShape shape: Tensor<Int32>) -> Tensor {
     return Raw.broadcastTo(self, shape: shape)
   }

   @inlinable
   @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   func broadcast(to shape: TensorShape) -> Tensor {
-    return broadcast(toShape: Tensor<Int32>({ shape.dimensions.map(Int32.init) }()))
+    return broadcast(
+      toShape: Tensor<Int32>({ shape.dimensions.map(Int32.init) }()))
   }

   /// Broadcast to the same shape as the specified `Tensor`.
   /// - Precondition: The specified shape must be compatible for broadcasting.
   @inlinable
-  @differentiable(wrt: self
-    where Scalar : TensorFlowFloatingPoint)
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   func broadcast<OtherScalar>(like other: Tensor<OtherScalar>) -> Tensor {
     return broadcast(toShape: other.shapeTensor)
   }
 }

 public extension Tensor where Scalar : Numeric {
   @inlinable
-  @differentiable(wrt: self, vjp: _vjpUnbroadcast(toShape:)
-    where Scalar : TensorFlowFloatingPoint)
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   func unbroadcast(toShape otherShape: Tensor<Int32>) -> Tensor {
-    let rankDiff = (rankTensor - otherShape.scalarCountTensor).rankLifted()
-    let ones: Tensor<Int32> = Raw.fill(dims: rankDiff, value: Tensor<Int32>(1))
-    let paddedShape = ones ++ otherShape
-    let nonEqualIndices = paddedShape .!= shapeTensor
-    let broadcastIndices = Raw.where_(nonEqualIndices).flattened()
-    let unbroadcasted: Tensor = Raw.sum(
-      self, reductionIndices: Tensor<Int32>(broadcastIndices), keepDims: false)
-    return Raw.reshape(unbroadcasted, shape: otherShape)
+    // TODO: Simplify this once differentiating control flow is supported.
+    return unbroadcast(to: {
+      precondition(otherShape.rank == 1)
+      return TensorShape(otherShape.scalars.map(Int.init))
+    }())
   }

   @inlinable
@@ -1644,9 +1640,31 @@ public extension Tensor where Scalar : Numeric {
   }

   @inlinable
-  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
+  @differentiable(wrt: self, vjp: _vjpUnbroadcast(to:)
+                  where Scalar : TensorFlowFloatingPoint)
   func unbroadcast(to shape: TensorShape) -> Tensor {
-    return unbroadcast(toShape: Tensor<Int32>({ shape.dimensions.map(Int32.init) }()))
+    let dimensions = self.shape.dimensions
+    var otherDimensions = shape.dimensions
+    let rankDifference = dimensions.count - otherDimensions.count
+    precondition(rankDifference >= 0, """
+      The rank of 'self' must be greater than or equal to the number of \
+      dimensions in the destination shape
+      """)
+    if rankDifference > 0 {
+      otherDimensions.insert(
+        contentsOf: repeatElement(1, count: rankDifference),
+        at: 0
+      )
+    }
+    assert(dimensions.count == otherDimensions.count)
+    var axes: [Int] = []
+    axes.reserveCapacity(dimensions.count)
+    for (i, (dim, otherDim)) in zip(dimensions, otherDimensions).enumerated() {
+      if dim == otherDim { continue }
+      if otherDim == 1 { axes.append(i); continue }
+      preconditionFailure("Cannot unbroadcast \(self.shape) to \(shape)")
+    }
+    return sum(alongAxes: axes).reshaped(to: shape)
   }

   @inlinable
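
As a usage note, broadcasting binary operators typically call `unbroadcast(to:)` in their pullbacks so each operand's gradient is reduced back to that operand's original shape. The following is a hedged sketch under that assumption, not the library's actual implementation (the commit message notes that the faster path for binary operators still relies on `Raw.broadcastGradientArgs(s0:s1:)` from #24408):

// Illustrative VJP for a broadcasting addition: the incoming cotangent `v`
// has the broadcasted result shape; each operand's gradient is unbroadcast
// back to that operand's captured original shape.
func vjpBroadcastingAdd<Scalar: TensorFlowFloatingPoint>(
  _ lhs: Tensor<Scalar>, _ rhs: Tensor<Scalar>
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
  let result = lhs + rhs
  return (result, { [lhsShape = lhs.shape, rhsShape = rhs.shape] v in
    (v.unbroadcast(to: lhsShape), v.unbroadcast(to: rhsShape))
  })
}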
