[TF] Remove unbroadcast(to:) and improve derivative performance. #24408

Closed
wants to merge 2 commits
62 changes: 41 additions & 21 deletions stdlib/public/TensorFlow/Gradients.swift
@@ -210,7 +210,10 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
return (lhs + rhs, {
[lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
return (v.unbroadcast(toShape: lhsShape), v.unbroadcast(toShape: rhsShape))
let (lhsAxes, rhsAxes) =
Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
v.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
})
}

@@ -220,30 +223,38 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
return (lhs - rhs, {
[lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
return (v.unbroadcast(toShape: lhsShape),
-v.unbroadcast(toShape: rhsShape))
let (lhsAxes, rhsAxes) =
Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),

Contributor:
I haven't looked too much, but I suspect that this extra reshape is not necessary. The lhsAxes should be sufficient to recover the original shape.

Contributor Author (@rxwei, May 1, 2019):
That’s what I tried initially (more specifically, ‘sum(alongAxes:)’) but it didn’t work.

@pschuh @rxwei the reshape is needed for handling dimensions with size 1. For example, say you do:

// x has shape [B, 5]
// y has shape [5]
// result has shape [B, 5]
let result = x + y

In this case, the broadcast indices for the gradient w.r.t. y will be [0], and so we'll do something like:

let yGrad = seed.sum(alongAxes: [0]) // no reshape needed.

Now, let y have shape [1, 5], which still broadcasts correctly for this example. The broadcast indices will now also be the same for the gradient (i.e., [0]). However, we need to do the reshape to recover the dimensions of size 1. Thus, the gradient needs to be computed as:

let yGrad = seed.sum(alongAxes: [0]).reshape(to: y.shape)

Having said that, I have a working implementation of these changes that I had made as part of a future swift-apis PR. I’ll try to open a PR here for this ASAP, but haven’t gotten the chance yet because I’m traveling to ICLR this week.
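
A minimal sketch of the size-1 case described above, using the same APIs the new VJPs call (the concrete shapes and variable names are illustrative only, not part of this PR):

import TensorFlow

// x broadcasts against y, whose leading dimension has size 1.
let x = Tensor<Float>(ones: [3, 5])      // shape [3, 5]
let y = Tensor<Float>(ones: [1, 5])      // shape [1, 5]
let seed = Tensor<Float>(ones: [3, 5])   // upstream gradient, broadcast shape [3, 5]

// The reduction axes for the gradient w.r.t. y are [0], the same as for a y of shape [5].
let (_, yAxes) = Raw.broadcastGradientArgs(s0: x.shapeTensor, s1: y.shapeTensor)
let summed = seed.sum(squeezingAxes: yAxes)          // shape [5]: the size-1 dimension is gone
let yGrad = summed.reshaped(toShape: y.shapeTensor)  // shape [1, 5]: matches y again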

Contributor:
To make sure we don't regress in the future, could you add a quick test case in your other PR to swift-apis? :-)

Yeap, I will go ahead and add that. Given that the merge already happened, is it ok to make this change after we move stdlib to swift-apis? I'll update the two PRs doing the move tonight.

Contributor Author (@rxwei):
Yes. Thanks!

-v.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
})
}

@inlinable
static func _vjpMultiply(
lhs: Tensor, rhs: Tensor
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
return (lhs * rhs, {
[lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
((rhs * v).unbroadcast(toShape: lhsShape),
(lhs * v).unbroadcast(toShape: rhsShape))
return (lhs * rhs, { v in
let (lhsShape, rhsShape) = (lhs.shapeTensor, rhs.shapeTensor)
let (lhsAxes, rhsAxes) =
Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return ((rhs * v).sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
(lhs * v).sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
})
}

@inlinable
static func _vjpDivide(
lhs: Tensor, rhs: Tensor
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
return (lhs / rhs, {
[lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
((v / rhs).unbroadcast(toShape: lhsShape),
((-lhs) / rhs.squared() * v).unbroadcast(toShape: rhsShape))
return (lhs / rhs, { v in
let (lhsShape, rhsShape) = (lhs.shapeTensor, rhs.shapeTensor)
let (lhsAxes, rhsAxes) =
Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return ((v / rhs).sum(squeezingAxes: lhsAxes)
.reshaped(toShape: lhsShape),
(-lhs / rhs.squared() * v).sum(squeezingAxes: rhsAxes)
.reshaped(toShape: rhsShape))
})
}
}
@@ -267,14 +278,14 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
static func _vjpSubtract(
lhs: Tensor, rhs: Scalar
) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) })
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is some legacy code introduced in the early days, when a where clause on @differentiable was not supported. That has since been fixed, so the workaround is no longer needed.

return (lhs - rhs, { v in (v, -v.sum().scalarized()) })
}

@inlinable
static func _vjpSubtract(
lhs: Scalar, rhs: Tensor
) -> (Tensor, (Tensor) -> (Scalar, Tensor)) {
return (lhs - rhs, { v in (v.sum().scalarized(), 0 - v) })
return (lhs - rhs, { v in (v.sum().scalarized(), -v) })
}

@inlinable
@@ -296,7 +307,7 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
lhs: Tensor, rhs: Scalar
) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
return (lhs / rhs, { v in
(v / rhs, (v * (0 - lhs) / Tensor(rhs).squared()).sum().scalarized())
(v / rhs, (v * -lhs / Tensor(rhs).squared()).sum().scalarized())
})
}

@@ -317,25 +328,30 @@ func _vjpMinMaxHelper<T : TensorFlowFloatingPoint>(
let denom = 1 + Tensor<T>(x .== y)
let dfdx = vector * Tensor<T>(x .== originalValue) / denom
let dfdy = vector * Tensor<T>(y .== originalValue) / denom
return (dfdx.unbroadcast(like: x), dfdy.unbroadcast(like: y))
let (xShape, yShape) = (x.shapeTensor, y.shapeTensor)
let (xAxes, yAxes) = Raw.broadcastGradientArgs(s0: xShape, s1: yShape)
return (dfdx.sum(squeezingAxes: xAxes).reshaped(toShape: xShape),
dfdy.sum(squeezingAxes: yAxes).reshaped(toShape: yShape))
}

@inlinable
func _vjpMax<T : TensorFlowFloatingPoint>(
_ x: Tensor<T>, _ y: Tensor<T>
) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
let value = max(x, y)
return (value,
{ v in _vjpMinMaxHelper(x, y, originalValue: value, vector: v) })
return (value, { v in
_vjpMinMaxHelper(x, y, originalValue: value, vector: v)
})
}

@inlinable
func _vjpMin<T : TensorFlowFloatingPoint>(
_ x: Tensor<T>, _ y: Tensor<T>
) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
let value = min(x, y)
return (value,
{ v in _vjpMinMaxHelper(x, y, originalValue: value, vector: v) })
return (value, { v in
_vjpMinMaxHelper(x, y, originalValue: value, vector: v)
})
}

@inlinable
@@ -344,8 +360,12 @@ func _vjpPow<T : TensorFlowFloatingPoint>(
) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
let value = pow(x, y)
return (value, { v in
((v * y * pow(x, y-1)).unbroadcast(like: x),
(v * log(x) * value).unbroadcast(like: y))
let (xShape, yShape) = (x.shapeTensor, y.shapeTensor)
let (xAxes, yAxes) = Raw.broadcastGradientArgs(s0: xShape, s1: yShape)
return ((v * y * pow(x, y-1)).sum(squeezingAxes: xAxes)
.reshaped(toShape: xShape),
(v * log(x) * value).sum(squeezingAxes: yAxes)
.reshaped(toShape: yShape))
})
}

24 changes: 0 additions & 24 deletions stdlib/public/TensorFlow/Ops.swift
@@ -1665,30 +1665,6 @@ public extension Tensor {
func broadcast<OtherScalar>(like other: Tensor<OtherScalar>) -> Tensor {
return broadcast(toShape: other.shapeTensor)
}
}

public extension Tensor where Scalar : Numeric {
@inlinable
func unbroadcast(toShape otherShape: Tensor<Int32>) -> Tensor {
let rankDiff = (rankTensor - otherShape.scalarCountTensor).rankLifted()
let ones: Tensor<Int32> = Raw.fill(dims: rankDiff, value: Tensor<Int32>(1))
let paddedShape = ones ++ otherShape
let nonEqualIndices = paddedShape .!= shapeTensor
let broadcastIndices = Raw.where_(nonEqualIndices).flattened()
let unbroadcasted: Tensor = Raw.sum(
self, reductionIndices: Tensor<Int32>(broadcastIndices), keepDims: false)
return Raw.reshape(unbroadcasted, shape: otherShape)
}

@inlinable @inline(__always)
func unbroadcast<OtherScalar>(like other: Tensor<OtherScalar>) -> Tensor {
return unbroadcast(toShape: other.shapeTensor)
}

@inlinable @inline(__always)
func unbroadcast(to shape: TensorShape) -> Tensor {
return unbroadcast(toShape: Tensor<Int32>(shape.dimensions.map(Int32.init)))
}

@inlinable @inline(__always)
static func .= (lhs: inout Tensor, rhs: Tensor) {
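
For reference only (not part of this PR), the removed unbroadcast(toShape:) above can be rewritten with the same Raw.broadcastGradientArgs pattern that the updated VJPs in Gradients.swift now inline. The helper name below is made up, and it is constrained to TensorFlowFloatingPoint, the only context in which the derivatives need it:

extension Tensor where Scalar : TensorFlowFloatingPoint {
  /// Sums away the broadcast axes of `self` and restores `otherShape`,
  /// assuming `otherShape` broadcasts to `self.shape`.
  func unbroadcastedViaGradientArgs(toShape otherShape: Tensor<Int32>) -> Tensor {
    // The first result lists the axes of the broadcasted shape that must be
    // reduced to recover a tensor of shape `otherShape`.
    let (axes, _) = Raw.broadcastGradientArgs(s0: otherShape, s1: shapeTensor)
    return sum(squeezingAxes: axes).reshaped(toShape: otherShape)
  }
}

Unlike the removed implementation, this does not need the padded-shape comparison and the Raw.where_ call to find the reduction indices.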
14 changes: 14 additions & 0 deletions test/TensorFlowRuntime/tensor_autodiff_runtime.swift
@@ -25,6 +25,18 @@ TensorADTests.testAllBackends("TestSimpleGrad") {
expectEqual([[20], [40]], gradient(at: [[10], [20]], in: square))
}

// TODO: This is also failing!
TensorADTests.testAllBackends("TestBroadcastingGrad") {
func foo(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
return x * y + x
}
let x = Tensor<Float>(ones: [1, 2, 1, 4])
let y = Tensor<Float>(ones: [4, 1, 3, 1])
let (dx, dy) = gradient(at: x, y, in: foo)
expectEqual(x.shape, dx.shape)
expectEqual(y.shape, dy.shape)
}
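
For illustration only (hypothetical test name, not part of this PR), a follow-up regression test along the lines discussed in the review thread above could target a broadcast dimension of size 1 specifically:

TensorADTests.testAllBackends("TestSizeOneBroadcastingGrad") {
  func f(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
    return x + y
  }
  // y's leading dimension has size 1; recovering its shape exercises the final reshape.
  let x = Tensor<Float>(ones: [4, 5])
  let y = Tensor<Float>(ones: [1, 5])
  let (dx, dy) = gradient(at: x, y, in: f)
  expectEqual(x.shape, dx.shape)
  expectEqual(y.shape, dy.shape)
  expectEqual(Tensor<Float>(ones: [4, 5]), dx)
  expectEqual(Tensor<Float>(ones: [1, 5]) * 4, dy)
}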

TensorADTests.testAllBackends("TestGenericGrad") {
func square<T : TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
return x * x
Expand Down Expand Up @@ -219,13 +231,15 @@ TensorADTests.testAllBackends("Differentiate global") {
}

TensorADTests.testAllBackends("Side effects") {
///* This is failing reshape for some reason
let foo: @differentiable (Tensor<Float>) -> Tensor<Float> = { x in
var a = x
a = a + x
a = a + x
return a + x
}
expectEqual(Tensor([8, 8]), pullback(at: Tensor(4), in: foo)([1, 1]))
//*/

func bar(x: Tensor<Float>) -> Tensor<Float> {
var a = x