Skip to content

Commit 16dc9a2

Browse files
committed
[AutoDiff] Support differentiation of loops.
- Change predecessor enum generation to support loops. - Generate indirect predecessor enums for BBs in loops. - Handle boxed payloads of indirect enums. - Traverse basic blocks in post-order post-dominance order. - This is necessary for computational correctness. - Add loop tests (`for-in`, `while`, nested).
1 parent 8d39256 commit 16dc9a2

File tree

8 files changed

+419
-153
lines changed

8 files changed

+419
-153
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 158 additions & 126 deletions
Large diffs are not rendered by default.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//===--- Differentiation.h - SIL Automatic Differentiation ----*- C++ -*---===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// SWIFT_ENABLE_TENSORFLOW
14+
//
15+
// Reverse-mode automatic differentiation utilities.
16+
//
17+
// NOTE: Although the AD feature is developed as part of the Swift for
18+
// TensorFlow project, it is completely independent from TensorFlow support.
19+
//
20+
// TODO: Move definitions here from Differentiation.cpp.
21+
//
22+
//===----------------------------------------------------------------------===//
23+
24+
#ifndef SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
25+
#define SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
26+
27+
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
28+
#include "swift/SILOptimizer/Utils/Local.h"
29+
30+
namespace swift {
31+
32+
using llvm::DenseMap;
33+
using llvm::SmallDenseMap;
34+
using llvm::SmallDenseSet;
35+
using llvm::SmallMapVector;
36+
using llvm::SmallSet;
37+
38+
/// Helper class for visiting basic blocks in post-order post-dominance order,
39+
/// based on a worklist algorithm.
40+
class PostOrderPostDominanceOrder {
41+
SmallVector<DominanceInfoNode *, 16> buffer;
42+
PostOrderFunctionInfo *postOrderInfo;
43+
size_t srcIdx = 0;
44+
45+
public:
46+
/// Constructor.
47+
/// \p root The root of the post-dominator tree.
48+
/// \p postOrderInfo The post-order info of the function.
49+
/// \p capacity Should be the number of basic blocks in the dominator tree to
50+
/// reduce memory allocation.
51+
PostOrderPostDominanceOrder(DominanceInfoNode *root,
52+
PostOrderFunctionInfo *postOrderInfo,
53+
int capacity = 0)
54+
: postOrderInfo(postOrderInfo) {
55+
buffer.reserve(capacity);
56+
buffer.push_back(root);
57+
}
58+
59+
/// Get the next block from the worklist.
60+
DominanceInfoNode *getNext() {
61+
if (srcIdx == buffer.size())
62+
return nullptr;
63+
return buffer[srcIdx++];
64+
}
65+
66+
/// Pushes the dominator children of a block onto the worklist in post-order.
67+
void pushChildren(DominanceInfoNode *node) {
68+
pushChildrenIf(node, [](SILBasicBlock *) { return true; });
69+
}
70+
71+
/// Conditionally pushes the dominator children of a block onto the worklist
72+
/// in post-order.
73+
template <typename Pred>
74+
void pushChildrenIf(DominanceInfoNode *node, Pred pred) {
75+
SmallVector<DominanceInfoNode *, 4> children;
76+
for (auto *child : *node)
77+
children.push_back(child);
78+
llvm::sort(children.begin(), children.end(),
79+
[&](DominanceInfoNode *n1, DominanceInfoNode *n2) {
80+
return postOrderInfo->getPONumber(n1->getBlock()) <
81+
postOrderInfo->getPONumber(n2->getBlock());
82+
});
83+
for (auto *child : children) {
84+
SILBasicBlock *childBB = child->getBlock();
85+
if (pred(childBB))
86+
buffer.push_back(child);
87+
}
88+
}
89+
};
90+
91+
} // end namespace swift
92+
93+
#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,16 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
181181
}
182182
}
183183

184+
extension Tracked where T : Differentiable & FloatingPoint,
185+
T == T.AllDifferentiableVariables, T == T.TangentVector {
186+
@usableFromInline
187+
@differentiating(/)
188+
internal static func _vjpDivide(lhs: Self, rhs: Self)
189+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
190+
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
191+
}
192+
}
193+
184194
// Differential operators for `Tracked<Float>`.
185195
public extension Differentiable {
186196
@inlinable

test/AutoDiff/control_flow.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,4 +519,46 @@ ControlFlowTests.test("Enums") {
519519
}
520520
}
521521

522+
ControlFlowTests.test("Loops") {
523+
func for_loop(_ x: Float) -> Float {
524+
var result = x
525+
for _ in 1..<3 {
526+
result = result * x
527+
}
528+
return result
529+
}
530+
expectEqual((8, 12), valueWithGradient(at: 2, in: for_loop))
531+
expectEqual((27, 27), valueWithGradient(at: 3, in: for_loop))
532+
533+
func while_loop(_ x: Float) -> Float {
534+
var result = x
535+
var i = 1
536+
while i < 3 {
537+
result = result * x
538+
i += 1
539+
}
540+
return result
541+
}
542+
expectEqual((8, 12), valueWithGradient(at: 2, in: while_loop))
543+
expectEqual((27, 27), valueWithGradient(at: 3, in: while_loop))
544+
545+
func nested_loop(_ x: Float) -> Float {
546+
var outer = x
547+
for _ in 1..<3 {
548+
outer = outer * x
549+
550+
var inner = outer
551+
var i = 1
552+
while i < 3 {
553+
inner = inner / x
554+
i += 1
555+
}
556+
outer = inner
557+
}
558+
return outer
559+
}
560+
expectEqual((0.5, -0.25), valueWithGradient(at: 2, in: nested_loop))
561+
expectEqual((0.25, -0.0625), valueWithGradient(at: 4, in: nested_loop))
562+
}
563+
522564
runAllTests()

test/AutoDiff/control_flow_diagnostics.swift

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,57 @@ func enum_nonactive2(_ e: Enum, _ x: Float) -> Float {
3333
}
3434
}
3535

36+
// Test loops.
37+
38+
@differentiable
39+
func for_loop(_ x: Float) -> Float {
40+
var result: Float = x
41+
for _ in 0..<3 {
42+
result = result * x
43+
}
44+
return result
45+
}
46+
47+
@differentiable
48+
func while_loop(_ x: Float) -> Float {
49+
var result = x
50+
var i = 1
51+
while i < 3 {
52+
result = result * x
53+
i += 1
54+
}
55+
return result
56+
}
57+
58+
@differentiable
59+
func nested_loop(_ x: Float) -> Float {
60+
var outer = x
61+
for _ in 1..<3 {
62+
outer = outer * x
63+
64+
var inner = outer
65+
var i = 1
66+
while i < 3 {
67+
inner = inner / x
68+
i += 1
69+
}
70+
outer = inner
71+
}
72+
return outer
73+
}
74+
75+
// Test `try_apply`.
76+
77+
// expected-error @+1 {{function is not differentiable}}
78+
@differentiable
79+
// expected-note @+1 {{when differentiating this function definition}}
80+
func withoutDerivative<T : Differentiable, R: Differentiable>(
81+
at x: T, in body: (T) throws -> R
82+
) rethrows -> R {
83+
// expected-note @+1 {{differentiating control flow is not yet supported}}
84+
try body(x)
85+
}
86+
3687
// Test unsupported differentiation of active enum values.
3788

3889
// expected-error @+1 {{function is not differentiable}}
@@ -90,17 +141,3 @@ enum Tree : Differentiable & AdditiveArithmetic {
90141
}
91142
}
92143
}
93-
94-
// Test loops.
95-
96-
// expected-error @+1 {{function is not differentiable}}
97-
@differentiable
98-
// expected-note @+1 {{when differentiating this function definition}}
99-
func loop(_ x: Float) -> Float {
100-
var result: Float = 1
101-
// expected-note @+1 {{differentiating loops is not yet supported}}
102-
for _ in 0..<3 {
103-
result += x
104-
}
105-
return x
106-
}

test/AutoDiff/control_flow_sil.swift

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,31 @@ func cond(_ x: Float) -> Float {
1616
return x - x
1717
}
1818

19-
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 {
20-
// CHECK-DATA-STRUCTURES: }
2119
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb0__PB__src_0_wrt_0 {
2220
// CHECK-DATA-STRUCTURES: }
23-
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 {
24-
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
25-
// CHECK-DATA-STRUCTURES: }
2621
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb1__PB__src_0_wrt_0 {
2722
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb1__Pred__src_0_wrt_0 { get set }
2823
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_0: (Float) -> (Float, Float) { get set }
2924
// CHECK-DATA-STRUCTURES: }
30-
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 {
31-
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
32-
// CHECK-DATA-STRUCTURES: }
3325
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb2__PB__src_0_wrt_0 {
3426
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb2__Pred__src_0_wrt_0 { get set }
3527
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_1: (Float) -> (Float, Float) { get set }
3628
// CHECK-DATA-STRUCTURES: }
29+
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 {
30+
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
31+
// CHECK-DATA-STRUCTURES: }
32+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 {
33+
// CHECK-DATA-STRUCTURES: }
34+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 {
35+
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
36+
// CHECK-DATA-STRUCTURES: }
37+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 {
38+
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
39+
// CHECK-DATA-STRUCTURES: }
3740
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb3__Pred__src_0_wrt_0 {
3841
// CHECK-DATA-STRUCTURES: case bb2(_AD__cond_bb2__PB__src_0_wrt_0)
3942
// CHECK-DATA-STRUCTURES: case bb1(_AD__cond_bb1__PB__src_0_wrt_0)
4043
// CHECK-DATA-STRUCTURES: }
41-
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 {
42-
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
43-
// CHECK-DATA-STRUCTURES: }
4444

4545
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
4646
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
@@ -137,6 +137,20 @@ func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) ->
137137
return y
138138
}
139139

140+
@differentiable
141+
@_silgen_name("loop_generic")
142+
func loop_generic<T : Differentiable & FloatingPoint>(_ x: T) -> T {
143+
var result = x
144+
for _ in 1..<3 {
145+
var y = x
146+
for _ in 1..<3 {
147+
result = y
148+
y = result
149+
}
150+
}
151+
return result
152+
}
153+
140154
// Test control flow + tuple buffer.
141155
// Verify that adjoint buffers are not allocated for address projections.
142156

test/AutoDiff/leakchecking.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,44 @@ LeakCheckingTests.test("ControlFlow") {
190190
expectEqual((-2674, 2), Tracked<Float>(-1337).valueWithGradient(in: { x in enum_notactive2(.b(4, 5), x) }))
191191
}
192192

193+
// FIXME: Fix control flow AD memory leaks.
194+
// See related FIXME comments in adjoint value/buffer propagation in
195+
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
196+
testWithLeakChecking(expectedLeakCount: 6) {
197+
func for_loop(_ x: Tracked<Float>) -> Tracked<Float> {
198+
var result = x
199+
for _ in 1..<3 {
200+
result = result * x
201+
}
202+
return result
203+
}
204+
expectEqual((8, 12), Tracked<Float>(2).valueWithGradient(in: for_loop))
205+
expectEqual((27, 27), Tracked<Float>(3).valueWithGradient(in: for_loop))
206+
}
207+
208+
// FIXME: Fix control flow AD memory leaks.
209+
// See related FIXME comments in adjoint value/buffer propagation in
210+
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
211+
testWithLeakChecking(expectedLeakCount: 20) {
212+
func nested_loop(_ x: Tracked<Float>) -> Tracked<Float> {
213+
var outer = x
214+
for _ in 1..<3 {
215+
outer = outer * x
216+
217+
var inner = outer
218+
var i = 1
219+
while i < 3 {
220+
inner = inner / x
221+
i += 1
222+
}
223+
outer = inner
224+
}
225+
return outer
226+
}
227+
expectEqual((0.5, -0.25), Tracked<Float>(2).valueWithGradient(in: nested_loop))
228+
expectEqual((0.25, -0.0625), Tracked<Float>(4).valueWithGradient(in: nested_loop))
229+
}
230+
193231
// FIXME: Fix control flow AD memory leaks.
194232
// See related FIXME comments in adjoint value/buffer propagation in
195233
// lib/SILOptimizer/Mandatory/Differentiation.cpp.

test/AutoDiff/refcounting.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ func testOwnedVector(_ x: Vector) -> Vector {
3636
}
3737
_ = pullback(at: Vector.zero, in: testOwnedVector)
3838

39-
// CHECK-LABEL: enum {{.*}}testOwnedVector{{.*}}__Pred__src_0_wrt_0 {
40-
// CHECK-NEXT: }
4139
// CHECK-LABEL: struct {{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0 {
4240
// CHECK-NEXT: @_hasStorage var pullback_0: (Vector) -> (Vector, Vector) { get set }
4341
// CHECK-NEXT: }
42+
// CHECK-LABEL: enum {{.*}}testOwnedVector{{.*}}__Pred__src_0_wrt_0 {
43+
// CHECK-NEXT: }
4444

4545
// CHECK-LABEL: sil hidden @{{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__adjoint_src_0_wrt_0_1
4646
// CHECK: bb0([[SEED:%.*]] : $Vector, [[PB_STRUCT:%.*]] : ${{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__PB__src_0_wrt_0_1):

0 commit comments

Comments
 (0)