Skip to content

Commit 766a1de

Browse files
authored
[AutoDiff] Support differentiation of loops. (#25558)
- Change predecessor enum generation to support loops. - Generate indirect predecessor enums for BBs in loops. - Handle boxed payloads of indirect enums. - Traverse basic blocks in post-order post-dominance order. - This is necessary for computational correctness. - Add loop tests (for-in, while, repeat-while, break/continue, nested). Expose TF-584: incorrect derivative computation for repeat-while loops. Expose TF-585: AllocBoxToStack crash with `-O` for nested loop AD.
1 parent 1782d72 commit 766a1de

File tree

8 files changed

+504
-147
lines changed

8 files changed

+504
-147
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 161 additions & 126 deletions
Large diffs are not rendered by default.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//===--- Differentiation.h - SIL Automatic Differentiation ----*- C++ -*---===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// SWIFT_ENABLE_TENSORFLOW
14+
//
15+
// Reverse-mode automatic differentiation utilities.
16+
//
17+
// NOTE: Although the AD feature is developed as part of the Swift for
18+
// TensorFlow project, it is completely independent from TensorFlow support.
19+
//
20+
// TODO: Move definitions here from Differentiation.cpp.
21+
//
22+
//===----------------------------------------------------------------------===//
23+
24+
#ifndef SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
25+
#define SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
26+
27+
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
28+
#include "swift/SILOptimizer/Utils/Local.h"
29+
30+
namespace swift {
31+
32+
using llvm::DenseMap;
33+
using llvm::SmallDenseMap;
34+
using llvm::SmallDenseSet;
35+
using llvm::SmallMapVector;
36+
using llvm::SmallSet;
37+
38+
/// Helper class for visiting basic blocks in post-order post-dominance order,
39+
/// based on a worklist algorithm.
40+
class PostOrderPostDominanceOrder {
41+
SmallVector<DominanceInfoNode *, 16> buffer;
42+
PostOrderFunctionInfo *postOrderInfo;
43+
size_t srcIdx = 0;
44+
45+
public:
46+
/// Constructor.
47+
/// \p root The root of the post-dominator tree.
48+
/// \p postOrderInfo The post-order info of the function.
49+
/// \p capacity Should be the number of basic blocks in the dominator tree to
50+
/// reduce memory allocation.
51+
PostOrderPostDominanceOrder(DominanceInfoNode *root,
52+
PostOrderFunctionInfo *postOrderInfo,
53+
int capacity = 0)
54+
: postOrderInfo(postOrderInfo) {
55+
buffer.reserve(capacity);
56+
buffer.push_back(root);
57+
}
58+
59+
/// Get the next block from the worklist.
60+
DominanceInfoNode *getNext() {
61+
if (srcIdx == buffer.size())
62+
return nullptr;
63+
return buffer[srcIdx++];
64+
}
65+
66+
/// Pushes the dominator children of a block onto the worklist in post-order.
67+
void pushChildren(DominanceInfoNode *node) {
68+
pushChildrenIf(node, [](SILBasicBlock *) { return true; });
69+
}
70+
71+
/// Conditionally pushes the dominator children of a block onto the worklist
72+
/// in post-order.
73+
template <typename Pred>
74+
void pushChildrenIf(DominanceInfoNode *node, Pred pred) {
75+
SmallVector<DominanceInfoNode *, 4> children;
76+
for (auto *child : *node)
77+
children.push_back(child);
78+
llvm::sort(children.begin(), children.end(),
79+
[&](DominanceInfoNode *n1, DominanceInfoNode *n2) {
80+
return postOrderInfo->getPONumber(n1->getBlock()) <
81+
postOrderInfo->getPONumber(n2->getBlock());
82+
});
83+
for (auto *child : children) {
84+
SILBasicBlock *childBB = child->getBlock();
85+
if (pred(childBB))
86+
buffer.push_back(child);
87+
}
88+
}
89+
};
90+
91+
} // end namespace swift
92+
93+
#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,17 @@ extension Tracked : SignedNumeric & Numeric where T : SignedNumeric, T == T.Magn
114114
}
115115

116116
public static func *= (lhs: inout Tracked, rhs: Tracked) {
117-
lhs = Tracked(lhs.value * rhs.value)
117+
lhs = lhs * rhs
118+
}
119+
}
120+
121+
extension Tracked where T : FloatingPoint {
122+
public static func / (lhs: Tracked, rhs: Tracked) -> Tracked {
123+
return Tracked(lhs.value / rhs.value)
124+
}
125+
126+
public static func /= (lhs: inout Tracked, rhs: Tracked) {
127+
lhs = lhs / rhs
118128
}
119129
}
120130

@@ -181,6 +191,16 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
181191
}
182192
}
183193

194+
extension Tracked where T : Differentiable & FloatingPoint,
195+
T == T.AllDifferentiableVariables, T == T.TangentVector {
196+
@usableFromInline
197+
@differentiating(/)
198+
internal static func _vjpDivide(lhs: Self, rhs: Self)
199+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
200+
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
201+
}
202+
}
203+
184204
// Differential operators for `Tracked<Float>`.
185205
public extension Differentiable {
186206
@inlinable

test/AutoDiff/control_flow.swift

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,4 +519,112 @@ ControlFlowTests.test("Enums") {
519519
}
520520
}
521521

522+
ControlFlowTests.test("Loops") {
523+
func for_loop(_ x: Float) -> Float {
524+
var result = x
525+
for _ in 1..<3 {
526+
result = result * x
527+
}
528+
return result
529+
}
530+
expectEqual((8, 12), valueWithGradient(at: 2, in: for_loop))
531+
expectEqual((27, 27), valueWithGradient(at: 3, in: for_loop))
532+
533+
func while_loop(_ x: Float) -> Float {
534+
var result = x
535+
var i = 1
536+
while i < 3 {
537+
result = result * x
538+
i += 1
539+
}
540+
return result
541+
}
542+
expectEqual((8, 12), valueWithGradient(at: 2, in: while_loop))
543+
expectEqual((27, 27), valueWithGradient(at: 3, in: while_loop))
544+
545+
func repeat_while_loop(_ x: Float) -> Float {
546+
var result = x
547+
var i = 1
548+
repeat {
549+
result = result * x
550+
i += 1
551+
} while i < 3
552+
return result
553+
}
554+
// FIXME(TF-584): Investigate incorrect (too big) gradient values
555+
// for repeat-while loops.
556+
// expectEqual((8, 12), valueWithGradient(at: 2, in: repeat_while_loop))
557+
// expectEqual((27, 27), valueWithGradient(at: 3, in: repeat_while_loop))
558+
expectEqual((8, 18), valueWithGradient(at: 2, in: repeat_while_loop))
559+
expectEqual((27, 36), valueWithGradient(at: 3, in: repeat_while_loop))
560+
561+
func loop_continue(_ x: Float) -> Float {
562+
var result = x
563+
for i in 1..<10 {
564+
if i.isMultiple(of: 2) {
565+
continue
566+
}
567+
result = result * x
568+
}
569+
return result
570+
}
571+
expectEqual((64, 192), valueWithGradient(at: 2, in: loop_continue))
572+
expectEqual((729, 1458), valueWithGradient(at: 3, in: loop_continue))
573+
574+
func loop_break(_ x: Float) -> Float {
575+
var result = x
576+
for i in 1..<10 {
577+
if i.isMultiple(of: 2) {
578+
continue
579+
}
580+
result = result * x
581+
}
582+
return result
583+
}
584+
expectEqual((64, 192), valueWithGradient(at: 2, in: loop_break))
585+
expectEqual((729, 1458), valueWithGradient(at: 3, in: loop_break))
586+
587+
func nested_loop1(_ x: Float) -> Float {
588+
var outer = x
589+
for _ in 1..<3 {
590+
outer = outer * x
591+
592+
var inner = outer
593+
var i = 1
594+
while i < 3 {
595+
inner = inner + x
596+
i += 1
597+
}
598+
outer = inner
599+
}
600+
return outer
601+
}
602+
expectEqual((20, 22), valueWithGradient(at: 2, in: nested_loop1))
603+
expectEqual((104, 66), valueWithGradient(at: 4, in: nested_loop1))
604+
605+
func nested_loop2(_ x: Float, count: Int) -> Float {
606+
var outer = x
607+
outerLoop: for _ in 1..<count {
608+
outer = outer * x
609+
610+
var inner = outer
611+
var i = 1
612+
while i < count {
613+
inner = inner + x
614+
i += 1
615+
616+
switch Int(inner.truncatingRemainder(dividingBy: 7)) {
617+
case 0: break outerLoop
618+
case 1: break
619+
default: continue
620+
}
621+
}
622+
outer = inner
623+
}
624+
return outer
625+
}
626+
expectEqual((24, 12), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 5) }))
627+
expectEqual((16, 8), valueWithGradient(at: 4, in: { x in nested_loop2(x, count: 5) }))
628+
}
629+
522630
runAllTests()

test/AutoDiff/control_flow_diagnostics.swift

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,57 @@ func enum_nonactive2(_ e: Enum, _ x: Float) -> Float {
3333
}
3434
}
3535

36+
// Test loops.
37+
38+
@differentiable
39+
func for_loop(_ x: Float) -> Float {
40+
var result: Float = x
41+
for _ in 0..<3 {
42+
result = result * x
43+
}
44+
return result
45+
}
46+
47+
@differentiable
48+
func while_loop(_ x: Float) -> Float {
49+
var result = x
50+
var i = 1
51+
while i < 3 {
52+
result = result * x
53+
i += 1
54+
}
55+
return result
56+
}
57+
58+
@differentiable
59+
func nested_loop(_ x: Float) -> Float {
60+
var outer = x
61+
for _ in 1..<3 {
62+
outer = outer * x
63+
64+
var inner = outer
65+
var i = 1
66+
while i < 3 {
67+
inner = inner / x
68+
i += 1
69+
}
70+
outer = inner
71+
}
72+
return outer
73+
}
74+
75+
// Test `try_apply`.
76+
77+
// expected-error @+1 {{function is not differentiable}}
78+
@differentiable
79+
// expected-note @+1 {{when differentiating this function definition}}
80+
func withoutDerivative<T : Differentiable, R: Differentiable>(
81+
at x: T, in body: (T) throws -> R
82+
) rethrows -> R {
83+
// expected-note @+1 {{differentiating control flow is not yet supported}}
84+
try body(x)
85+
}
86+
3687
// Test unsupported differentiation of active enum values.
3788

3889
// expected-error @+1 {{function is not differentiable}}
@@ -91,16 +142,14 @@ enum Tree : Differentiable & AdditiveArithmetic {
91142
}
92143
}
93144

94-
// Test loops.
95-
96145
// expected-error @+1 {{function is not differentiable}}
97146
@differentiable
98147
// expected-note @+1 {{when differentiating this function definition}}
99-
func loop(_ x: Float) -> Float {
148+
func loop_array(_ array: [Float]) -> Float {
100149
var result: Float = 1
101-
// expected-note @+1 {{differentiating loops is not yet supported}}
102-
for _ in 0..<3 {
103-
result += x
150+
// expected-note @+1 {{differentiating enum values is not yet supported}}
151+
for x in array {
152+
result = result * x
104153
}
105-
return x
154+
return result
106155
}

test/AutoDiff/control_flow_sil.swift

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,31 @@ func cond(_ x: Float) -> Float {
1616
return x - x
1717
}
1818

19-
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 {
20-
// CHECK-DATA-STRUCTURES: }
2119
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb0__PB__src_0_wrt_0 {
2220
// CHECK-DATA-STRUCTURES: }
23-
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 {
24-
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
25-
// CHECK-DATA-STRUCTURES: }
2621
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb1__PB__src_0_wrt_0 {
2722
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb1__Pred__src_0_wrt_0 { get set }
2823
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_0: (Float) -> (Float, Float) { get set }
2924
// CHECK-DATA-STRUCTURES: }
30-
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 {
31-
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
32-
// CHECK-DATA-STRUCTURES: }
3325
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb2__PB__src_0_wrt_0 {
3426
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb2__Pred__src_0_wrt_0 { get set }
3527
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_1: (Float) -> (Float, Float) { get set }
3628
// CHECK-DATA-STRUCTURES: }
29+
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 {
30+
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
31+
// CHECK-DATA-STRUCTURES: }
32+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 {
33+
// CHECK-DATA-STRUCTURES: }
34+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 {
35+
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
36+
// CHECK-DATA-STRUCTURES: }
37+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 {
38+
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
39+
// CHECK-DATA-STRUCTURES: }
3740
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb3__Pred__src_0_wrt_0 {
3841
// CHECK-DATA-STRUCTURES: case bb2(_AD__cond_bb2__PB__src_0_wrt_0)
3942
// CHECK-DATA-STRUCTURES: case bb1(_AD__cond_bb1__PB__src_0_wrt_0)
4043
// CHECK-DATA-STRUCTURES: }
41-
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 {
42-
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
43-
// CHECK-DATA-STRUCTURES: }
4444

4545
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
4646
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
@@ -137,6 +137,20 @@ func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) ->
137137
return y
138138
}
139139

140+
@differentiable
141+
@_silgen_name("loop_generic")
142+
func loop_generic<T : Differentiable & FloatingPoint>(_ x: T) -> T {
143+
var result = x
144+
for _ in 1..<3 {
145+
var y = x
146+
for _ in 1..<3 {
147+
result = y
148+
y = result
149+
}
150+
}
151+
return result
152+
}
153+
140154
// Test control flow + tuple buffer.
141155
// Verify that adjoint buffers are not allocated for address projections.
142156

0 commit comments

Comments
 (0)