Commit c4bac79

bartchr808 authored and rxwei committed
[AutoDiff] TF-509: Make Tensor.broadcast(to:) differentiable (#24859)
* Add all VJP functions, need to write tests.
* PR feedback batch #1.
* Use closure call to remove VJPs.
* Start adding tests for (un)broadcast(toShape:).
1 parent ff2818a commit c4bac79
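
For orientation, here is a minimal usage sketch of what this change enables (assuming a Swift for TensorFlow toolchain with this commit applied, and that `sum()` is also differentiable there; the values are illustrative, not taken from the diff):

import TensorFlow

// Differentiate through a broadcast: the pullback sums the incoming
// gradient over the broadcast dimensions, so each source element gets
// one contribution per copy the broadcast made of it.
let x = Tensor<Float>([1, 2, 3])               // shape [3]
let grad = gradient(at: x) { x -> Tensor<Float> in
  x.broadcast(to: [2, 3]).sum()                // [3] -> [2, 3], then reduce
}
print(grad)                                    // [2.0, 2.0, 2.0]

Each element of `x` appears twice in the broadcast result, so its gradient under a sum is 2.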

File tree

3 files changed, +166 -2 lines changed

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 24 additions & 0 deletions
@@ -635,3 +635,27 @@ func _vjpRelu<T : TensorFlowFloatingPoint>(
 ) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
   return (relu(x), { v in Tensor(x .> 0) * v })
 }
+
+//===----------------------------------------------------------------------===//
+// Broadcasting
+//===----------------------------------------------------------------------===//
+
+extension Tensor where Scalar : TensorFlowFloatingPoint {
+  @inlinable
+  func _vjpBroadcast(
+    toShape shape: Tensor<Int32>
+  ) -> (Tensor, (Tensor) -> Tensor) {
+    return (broadcast(toShape: shape), { [origShape = self.shapeTensor] v in
+      v.unbroadcast(toShape: origShape)
+    })
+  }
+
+  @inlinable
+  func _vjpUnbroadcast(
+    toShape shape: Tensor<Int32>
+  ) -> (Tensor, (Tensor) -> Tensor) {
+    return (unbroadcast(toShape: shape), { [origShape = self.shapeTensor] v in
+      v.broadcast(toShape: origShape)
+    })
+  }
+}
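
The two VJPs above are duals: broadcasting repeats values, so its pullback must sum the cotangent back down to the original shape, which is exactly `unbroadcast(toShape:)`; conversely, the pullback of `unbroadcast` re-broadcasts the cotangent. A quick sketch of the duality (illustrative values, same toolchain assumption as above):

import TensorFlow

let x = Tensor<Float>([1, 2, 3])                          // shape [3]
// Pullback of broadcast(toShape:) at x, per _vjpBroadcast above.
let pb = pullback(at: x) { $0.broadcast(toShape: Tensor<Int32>([2, 3])) }
let v = Tensor<Float>(ones: [2, 3])                       // cotangent of the output
print(pb(v))                                              // [2.0, 2.0, 2.0]: v summed down to x's shape

Note that the VJP closures capture `origShape = self.shapeTensor` rather than `self`, presumably so the backward pass keeps only the shape tensor alive instead of the whole broadcast input.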

stdlib/public/TensorFlow/Ops.swift

Lines changed: 11 additions & 2 deletions
@@ -1600,25 +1600,32 @@ public extension Tensor {
 
 public extension Tensor {
   @inlinable
+  @differentiable(wrt: self, vjp: _vjpBroadcast(toShape:)
+                  where Scalar : TensorFlowFloatingPoint)
   func broadcast(toShape shape: Tensor<Int32>) -> Tensor {
     return Raw.broadcastTo(self, shape: shape)
   }
 
   @inlinable
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   func broadcast(to shape: TensorShape) -> Tensor {
-    return broadcast(toShape: Tensor<Int32>(shape.dimensions.map(Int32.init)))
+    return broadcast(toShape: Tensor<Int32>({ shape.dimensions.map(Int32.init) }()))
   }
 
   /// Broadcast to the same shape as the specified `Tensor`.
   /// - Precondition: The specified shape must be compatible for broadcasting.
   @inlinable
+  @differentiable(wrt: self
+                  where Scalar : TensorFlowFloatingPoint)
   func broadcast<OtherScalar>(like other: Tensor<OtherScalar>) -> Tensor {
     return broadcast(toShape: other.shapeTensor)
   }
 }
 
 public extension Tensor where Scalar : Numeric {
   @inlinable
+  @differentiable(wrt: self, vjp: _vjpUnbroadcast(toShape:)
+                  where Scalar : TensorFlowFloatingPoint)
   func unbroadcast(toShape otherShape: Tensor<Int32>) -> Tensor {
     let rankDiff = (rankTensor - otherShape.scalarCountTensor).rankLifted()
     let ones: Tensor<Int32> = Raw.fill(dims: rankDiff, value: Tensor<Int32>(1))
@@ -1631,13 +1638,15 @@ public extension Tensor where Scalar : Numeric {
   }
 
   @inlinable
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   func unbroadcast<OtherScalar>(like other: Tensor<OtherScalar>) -> Tensor {
     return unbroadcast(toShape: other.shapeTensor)
   }
 
   @inlinable
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   func unbroadcast(to shape: TensorShape) -> Tensor {
-    return unbroadcast(toShape: Tensor<Int32>(shape.dimensions.map(Int32.init)))
+    return unbroadcast(toShape: Tensor<Int32>({ shape.dimensions.map(Int32.init) }()))
   }
 
   @inlinable
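Note the `{ shape.dimensions.map(Int32.init) }()` wrapper that replaces the direct call in `broadcast(to:)` and `unbroadcast(to:)`. Per the commit message bullet "Use closure call to remove VJPs", hoisting the array conversion into an immediately applied closure keeps that non-Tensor computation out of the differentiable expression, presumably so the differentiation transform does not demand derivatives for it. The pattern in isolation (a hypothetical, standalone sketch; `dims` is an invented name):

import TensorFlow

// The [Int] -> [Int32] work happens inside `{ ... }()`; only the resulting
// Tensor<Int32> participates in the surrounding expression.
let dims = [3, 3]
let shapeTensor = Tensor<Int32>({ dims.map(Int32.init) }())
print(shapeTensor)  // [3, 3]
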
test/TensorFlowRuntime/tensor_autodiff_runtime.swift

Lines changed: 131 additions & 0 deletions
@@ -262,4 +262,135 @@ TensorADTests.testAllBackends("Side effects") {
   expectEqual(Tensor(48), gradient(at: Tensor(4), in: bar))
 }
 
+TensorADTests.testAllBackends("broadcast(toShape:)") {
+  func foo(tensor: Tensor<Float>, shape: Tensor<Int32>) -> Tensor<Float> {
+    tensor.broadcast(toShape: shape)
+  }
+
+  var inputTensor: Tensor<Float>
+  var expected: Tensor<Float>
+  var pb: (Tensor<Float>) -> Tensor<Float>
+
+  // [3] -> [3,3]
+  pb = pullback(at: Tensor([99, 33, 55])) { x in
+    foo(tensor: x, shape: Tensor([3, 3]))
+  }
+
+  // Test 1: same shape as parameter of pullback
+  inputTensor = Tensor([
+    [1, 2, 3],
+    [1, 2, 3],
+    [1, 2, 3]]
+  )
+  expected = Tensor([3, 6, 9])
+  expectEqual(expected, pb(inputTensor))
+
+  // Test 2: different shape than parameter of pullback
+  inputTensor = Tensor([
+    [1, 2, 3],
+    [1, 2, 3],
+    [1, 2, 3],
+    [1, 2, 3]]
+  )
+  expected = Tensor([4, 8, 12])
+  expectEqual(expected, pb(inputTensor))
+
+  // Test 3: same shape as tensor we are differentiating at
+  inputTensor = Tensor([1, 2, 3])
+  expected = Tensor([1, 2, 3])
+  expectEqual(expected, pb(inputTensor))
+
+  // Test 4: extremely padded shape as tensor we are differentiating at
+  inputTensor = Tensor([[[[[[1, 2, 3]]]]]])
+  expected = Tensor([1, 2, 3])
+  expectEqual(expected, pb(inputTensor))
+
+  // [1,3] -> [3,3]
+  pb = pullback(at: Tensor([[99, 33, 55]])) { x in
+    foo(tensor: x, shape: Tensor([3, 3]))
+  }
+
+  // Test 5: same shape as parameter of pullback
+  inputTensor = Tensor([
+    [1, 2, 3],
+    [1, 2, 3],
+    [1, 2, 3]]
+  )
+  expected = Tensor([[3, 6, 9]])
+  expectEqual(expected, pb(inputTensor))
+
+  // Test 6: different shape than parameter of pullback
+  inputTensor = Tensor([
+    [1, 2, 3],
+    [1, 2, 3],
+    [1, 2, 3],
+    [1, 2, 3]]
+  )
+  expected = Tensor([[4, 8, 12]])
+  expectEqual(expected, pb(inputTensor))
+
+  // Test 7: same shape as tensor we are differentiating at
+  inputTensor = Tensor([[1, 2, 3]])
+  expected = Tensor([[1, 2, 3]])
+  expectEqual(expected, pb(inputTensor))
+
+  // Test 8: extremely padded shape of tensor we are differentiating at
+  inputTensor = Tensor([[[[[[1, 2, 3]]]]]])
+  expected = Tensor([[1, 2, 3]])
+  expectEqual(expected, pb(inputTensor))
+}
+
+TensorADTests.testAllBackends("unbroadcast(toShape:)") {
+  func foo(tensor: Tensor<Float>, shape: Tensor<Int32>) -> Tensor<Float> {
+    tensor.unbroadcast(toShape: shape)
+  }
+
+  var inputTensor: Tensor<Float>
+  var expected: Tensor<Float>
+  var pb: (Tensor<Float>) -> Tensor<Float>
+
+  // [3,3] -> [1,3]
+  let atTensor: Tensor<Float> = Tensor([
+    [1, 2, 3],
+    [1, 2, 3],
+    [1, 2, 3]]
+  )
+  pb = pullback(at: atTensor) { x in
+    foo(tensor: x, shape: Tensor([1, 3]))
+  }
+
+  // Test 1: same shape as parameter of pullback
+  inputTensor = Tensor([[1, 2, 3]])
+  expected = atTensor
+  expectEqual(expected, pb(inputTensor))
+
+  // Test 2: different shape than parameter of pullback
+  inputTensor = Tensor([2])
+  expected = Tensor([
+    [2, 2, 2],
+    [2, 2, 2],
+    [2, 2, 2]]
+  )
+  expectEqual(expected, pb(inputTensor))
+
+  // Test 3: same shape as tensor we are differentiating at
+  inputTensor = Tensor([
+    [8, 1, 3],
+    [8, 1, 3],
+    [8, 1, 3]]
+  )
+  expected = inputTensor
+  expectEqual(expected, pb(inputTensor))
+
+  // TODO
+  // Test 4: extremely padded shape as tensor we are differentiating at
+  // inputTensor = Tensor([
+  //   [[8, 1, 3]],
+  //   [[8, 1, 3]],
+  //   [[8, 1, 3]]]
+  // )
+  // expected = Tensor([1, 2, 3])
+  // expectEqual(expected, pb(inputTensor))
+}
+
 runAllTests()

0 commit comments