Skip to content

Commit 045e192

Browse files
dan-zheng authored and rxwei committed
[AutoDiff] Support differentiation of conditionals. (#25057)
- Support control flow in adjoint generation.
  - Make adjoint value/buffer mappings be per basic block.
  - Change `AdjointValue` to not be move-only. Original values from different basic blocks may share the same `AdjointValue`.
  - Propagate adjoint values from active bb arguments to predecessor terminator operands.
  - Propagate adjoint values/buffers of dominated active values/buffers to predecessor blocks.
    - For active values: propagate adjoint values as adjoint bb arguments.
    - For active buffers: propagate adjoint buffers via `copy_addr`.
- Revamp `AdjointEmitter` handling of `begin_access` and `end_access`.
  - `getAdjointValue` of `begin_access` now returns the adjoint base buffer. Previously, it returned a `begin_access` of the adjoint base buffer without generating a corresponding `end_access`.
  - `AdjointEmitter::visitBeginAccessInst` now generates no code.
  - `AdjointEmitter::visitEndAccessInst` now does nothing.
- Add various control flow differentiation tests.
  - Test differentiation of conditionals (nested), recursion, `var` allocations (tuples, structs).
  - Add negative leakchecking tests.

Todos:
- Fix adjoint value/buffer propagation memory leaks.
- Add more tests (adjoint SIL, leakchecking).
- Support differentiation of enum-related instructions and loops.
1 parent fc896ca commit 045e192

File tree

8 files changed

+1115
-396
lines changed

8 files changed

+1115
-396
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 451 additions & 323 deletions
Large diffs are not rendered by default.

test/AutoDiff/control_flow.swift

Lines changed: 357 additions & 67 deletions
Large diffs are not rendered by default.

test/AutoDiff/control_flow_diagnostics.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify %s
1+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-enable-control-flow %s
2+
3+
// Test supported `br` and `cond_br` terminators.
4+
5+
@differentiable
6+
func branch(_ x: Float) -> Float {
7+
if x > 0 {
8+
return x
9+
} else if x < 10 {
10+
return x
11+
}
12+
return x
13+
}
214

315
// Test currently unsupported `switch_enum` terminator.
416

test/AutoDiff/control_flow_sil.swift

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-enable-control-flow %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
2+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-enable-control-flow %s | %FileCheck %s -check-prefix=CHECK-SIL
3+
4+
// TODO: Add adjoint SIL FileCheck tests.
5+
6+
// Test conditional: a simple if-diamond.
7+
8+
@differentiable
9+
@_silgen_name("cond")
10+
func cond(_ x: Float) -> Float {
11+
if x > 0 {
12+
return x + x
13+
}
14+
return x - x
15+
}
16+
17+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 {
18+
// CHECK-DATA-STRUCTURES: }
19+
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb0__PB__src_0_wrt_0 {
20+
// CHECK-DATA-STRUCTURES: }
21+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 {
22+
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
23+
// CHECK-DATA-STRUCTURES: }
24+
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb1__PB__src_0_wrt_0 {
25+
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb1__Pred__src_0_wrt_0 { get set }
26+
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_0: (Float) -> (Float, Float) { get set }
27+
// CHECK-DATA-STRUCTURES: }
28+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 {
29+
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
30+
// CHECK-DATA-STRUCTURES: }
31+
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb2__PB__src_0_wrt_0 {
32+
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb2__Pred__src_0_wrt_0 { get set }
33+
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_1: (Float) -> (Float, Float) { get set }
34+
// CHECK-DATA-STRUCTURES: }
35+
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb3__Pred__src_0_wrt_0 {
36+
// CHECK-DATA-STRUCTURES: case bb2(_AD__cond_bb2__PB__src_0_wrt_0)
37+
// CHECK-DATA-STRUCTURES: case bb1(_AD__cond_bb1__PB__src_0_wrt_0)
38+
// CHECK-DATA-STRUCTURES: }
39+
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 {
40+
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
41+
// CHECK-DATA-STRUCTURES: }
42+
43+
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0
44+
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
45+
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 ()
46+
// CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
47+
// CHECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
48+
// CHECK-SIL: cond_br {{%.*}}, bb1([[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0), bb2([[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_0)
49+
50+
// CHECK-SIL: bb1([[BB1_PRED_ARG:%.*]] : $_AD__cond_bb1__Pred__src_0_wrt_0)
51+
// CHECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__cond_bb1__PB__src_0_wrt_0
52+
// CHECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1, [[BB1_PB_STRUCT]]
53+
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
54+
55+
// CHECK-SIL: bb2([[BB2_PRED_ARG:%.*]] : $_AD__cond_bb2__Pred__src_0_wrt_0)
56+
// CHECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__cond_bb2__PB__src_0_wrt_0
57+
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1, [[BB2_PB_STRUCT]]
58+
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
59+
60+
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
61+
// CHECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__cond_bb3__PB__src_0_wrt_0
62+
// CHECK-SIL: [[ADJOINT_REF:%.*]] = function_ref @AD__cond__adjoint_src_0_wrt_0
63+
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[ADJOINT_REF]]([[BB3_PB_STRUCT]])
64+
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
65+
// CHECK-SIL: return [[VJP_RESULT]]
66+
67+
@differentiable
68+
@_silgen_name("nested_cond")
69+
func nested_cond(_ x: Float, _ y: Float) -> Float {
70+
if x > 0 {
71+
if y > 10 {
72+
return x * y
73+
} else {
74+
return x + y
75+
}
76+
}
77+
return y - x
78+
}
79+
80+
@differentiable
81+
@_silgen_name("nested_cond_generic")
82+
func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) -> T {
83+
if x > 0 {
84+
if y > 10 {
85+
return y
86+
} else {
87+
return x
88+
}
89+
}
90+
return y
91+
}

test/AutoDiff/leakchecking.swift

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-run-simple-swift
1+
// RUN: %target-run-simple-swift-control-flow-differentiation
22
// REQUIRES: executable_test
33

44
// A test that we can properly differentiate types that require refcounting.
@@ -8,6 +8,13 @@ import DifferentiationUnittest
88

99
var LeakCheckingTests = TestSuite("LeakChecking")
1010

11+
/// Execute body, check expected leak count, and reset global leak count.
12+
func testWithLeakChecking(expectedLeakCount: Int = 0, _ body: () -> Void) {
13+
body()
14+
expectEqual(expectedLeakCount, _GlobalLeakCount.count, "Leak detected.")
15+
_GlobalLeakCount.count = 0
16+
}
17+
1118
struct ExampleLeakModel : Differentiable {
1219
var bias: Tracked<Float> = 2.0
1320
func applied(to input: Tracked<Float>) -> Tracked<Float> {
@@ -22,7 +29,45 @@ LeakCheckingTests.test("BasicVarLeakChecking") {
2229
let x: Tracked<Float> = 1.0
2330
let _ = model.gradient(at: x) { m, x in m.applied(to: x) }
2431
}
25-
expectEqual(0, _GlobalLeakCount.count, "Leak Detected.")
32+
expectEqual(0, _GlobalLeakCount.count, "Leak detected.")
33+
}
34+
35+
LeakCheckingTests.test("ControlFlow") {
36+
// TODO: Add more `var` + control flow tests.
37+
// Porting tests from test/AutoDiff/control_flow.swift requires more support
38+
// for `Tracked<Float>`.
39+
40+
// FIXME: Fix control flow AD memory leaks.
41+
// See related FIXME comments in adjoint value/buffer propagation in
42+
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
43+
testWithLeakChecking(expectedLeakCount: 9) {
44+
var model = ExampleLeakModel()
45+
let x: Tracked<Float> = 1.0
46+
let _ = model.gradient(at: x) { m, x in
47+
let result: Tracked<Float>
48+
if x > 0 {
49+
result = m.applied(to: x)
50+
} else {
51+
result = x
52+
}
53+
return result
54+
}
55+
}
56+
57+
// FIXME: Fix control flow AD memory leaks.
58+
// See related FIXME comments in adjoint value/buffer propagation in
59+
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
60+
testWithLeakChecking(expectedLeakCount: 14) {
61+
var model = ExampleLeakModel()
62+
let x: Tracked<Float> = 1.0
63+
let _ = model.gradient(at: x) { m, x in
64+
var result: Tracked<Float> = x
65+
if x > 0 {
66+
result = result + m.applied(to: x)
67+
}
68+
return result
69+
}
70+
}
2671
}
2772

2873
runAllTests()

test/AutoDiff/refcounting.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,16 @@ _ = pullback(at: Vector.zero, in: testOwnedVector)
6060
// CHECK: return [[NEEDED_COTAN1]] : $Vector
6161

6262
// CHECK-LABEL: sil hidden @{{.*}}side_effect_release_zero{{.*}}__adjoint_src_0_wrt_0
63-
// CHECK: bb0([[X:%.*]] : $Vector, %1 : ${{.*}}side_effect_release_zero{{.*}}_bb0__PB__src_0_wrt_0):
64-
// CHECK: retain_value [[SEED:%.*]] : $Vector
63+
// CHECK: bb0([[SEED:%.*]] : $Vector, %1 : ${{.*}}side_effect_release_zero{{.*}}_bb0__PB__src_0_wrt_0):
6564
// CHECK: [[BUF:%.*]] = alloc_stack $Vector
6665
// CHECK: [[BUF_ACCESS:%.*]] = begin_access [init] [static] [no_nested_conflict] [[BUF]] : $*Vector
6766
// CHECK: [[ZERO_GETTER:%.*]] = function_ref @$s11refcounting6VectorV4zeroACvgZ
6867
// CHECK: [[ZERO:%.*]] = apply [[ZERO_GETTER]]({{%.*}}) : $@convention(method) (@thin Vector.Type) -> @owned Vector
6968
// CHECK: store [[ZERO]] to [[BUF_ACCESS]] : $*Vector
69+
// CHECK: retain_value [[SEED:%.*]] : $Vector
70+
// CHECK: release_value [[SEED:%.*]] : $Vector
7071
// CHECK: destroy_addr [[BUF]] : $*Vector
7172
// CHECK: dealloc_stack [[BUF]] : $*Vector
72-
// CHECK: release_value [[SEED:%.*]] : $Vector
7373
// CHECK: }
7474

7575
// The vjp should not release pullback values.
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// RUN: %target-run-simple-swift-control-flow-differentiation
2+
// REQUIRES: executable_test
3+
//
4+
// FIXME(TF-326): Re-enable `-O` after deserialization failure fix.
5+
// UNSUPPORTED: swift_test_mode_optimize
6+
//
7+
// Tensor control flow AD runtime tests.
8+
// TODO: Move TensorFlow-specific AD tests into test/AutoDiff.
9+
10+
import TensorFlow
11+
import StdlibUnittest
12+
import TensorFlowUnittest
13+
14+
var TensorADTests = TestSuite("TensorControlFlowAD")
15+
16+
TensorADTests.testAllBackends("Conditionals") {
17+
func cond_nestedtuple_var(_ x: Tensor<Float>) -> Tensor<Float> {
18+
// Convoluted function returning `x + x`.
19+
var y: (Tensor<Float>, Tensor<Float>) = (x + x, x - x)
20+
var z: ((Tensor<Float>, Tensor<Float>), Tensor<Float>) = (y, x)
21+
if x > 0 {
22+
var w = (x, x)
23+
y.0 = w.1
24+
y.1 = w.0
25+
z.0.0 = z.0.0 - y.0
26+
z.0.1 = z.0.1 + y.0
27+
} else {
28+
z = ((y.0 - x, y.1 + x), x)
29+
}
30+
return y.0 + y.1 - z.0.0 + z.0.1
31+
}
32+
expectEqual((Tensor(8), Tensor(2)),
33+
valueWithGradient(at: Tensor(4), in: cond_nestedtuple_var))
34+
expectEqual((Tensor(-20), Tensor(2)),
35+
valueWithGradient(at: Tensor(-10), in: cond_nestedtuple_var))
36+
expectEqual((Tensor(-2674), Tensor(2)),
37+
valueWithGradient(at: Tensor(-1337), in: cond_nestedtuple_var))
38+
39+
func guard2_var(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
40+
var z = y
41+
guard x > 0 else {
42+
if y > 0 {
43+
z = z * x
44+
} else if x == Tensor(-1337) {
45+
z = x
46+
z = z * z
47+
} else {
48+
z = Tensor(0)
49+
}
50+
return z
51+
}
52+
return z * y
53+
}
54+
expectEqual((Tensor(0), Tensor(10)),
55+
gradient(at: Tensor(4), Tensor(5), in: guard2_var))
56+
expectEqual((Tensor(5), Tensor(-1337)),
57+
gradient(at: Tensor(-1337), Tensor(5), in: guard2_var))
58+
expectEqual((Tensor(-2674), Tensor(0)),
59+
gradient(at: Tensor(-1337), Tensor(-5), in: guard2_var))
60+
expectEqual((Tensor(2), Tensor(-3)),
61+
gradient(at: Tensor(-3), Tensor(2), in: guard2_var))
62+
}
63+
64+
TensorADTests.testAllBackends("NestedConditionals") {
65+
// Test tensor-tensor ops.
66+
func cond_nested1(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
67+
if x > 0 {
68+
if y > 10 {
69+
let z = x * y
70+
if z > 100 {
71+
return x + z
72+
} else if y == Tensor(20) {
73+
return z + z
74+
}
75+
} else {
76+
return x + y
77+
}
78+
}
79+
return -y
80+
}
81+
82+
expectEqual((Tensor(40), Tensor(8)),
83+
gradient(at: Tensor(4), Tensor(20), in: cond_nested1))
84+
expectEqual((Tensor(0), Tensor(-1)),
85+
gradient(at: Tensor(4), Tensor(21), in: cond_nested1))
86+
expectEqual((Tensor(1), Tensor(1)),
87+
gradient(at: Tensor(4), Tensor(5), in: cond_nested1))
88+
expectEqual((Tensor(0), Tensor(-1)),
89+
gradient(at: Tensor(-3), Tensor(-2), in: cond_nested1))
90+
91+
// Test tensor-scalar ops.
92+
func cond_nested2(_ x: Tensor<Float>, _ y: Float) -> Tensor<Float> {
93+
if x > 0 {
94+
if y > 10 {
95+
let z = x * y
96+
if z > 100 {
97+
return x + z
98+
} else if y == 20 {
99+
return z + z
100+
}
101+
} else {
102+
return x + y
103+
}
104+
}
105+
return Tensor(-y)
106+
}
107+
108+
expectEqual((Tensor(40), 8), gradient(at: Tensor(4), 20, in: cond_nested2))
109+
expectEqual((Tensor(0), -1), gradient(at: Tensor(4), 21, in: cond_nested2))
110+
expectEqual((Tensor(1), 1), gradient(at: Tensor(4), 5, in: cond_nested2))
111+
expectEqual((Tensor(0), -1), gradient(at: Tensor(-3), -2, in: cond_nested2))
112+
}
113+
114+
TensorADTests.testAllBackends("Recursion") {
115+
func factorial(_ x: Tensor<Float>) -> Tensor<Float> {
116+
if x == Tensor(1) {
117+
return Tensor(1)
118+
}
119+
return x * factorial(x - 1)
120+
}
121+
expectEqual(Tensor(0), gradient(at: Tensor(1), in: factorial))
122+
expectEqual(Tensor(1), gradient(at: Tensor(2), in: factorial))
123+
expectEqual(Tensor(5), gradient(at: Tensor(3), in: factorial))
124+
expectEqual(Tensor(26), gradient(at: Tensor(4), in: factorial))
125+
expectEqual(Tensor(154), gradient(at: Tensor(5), in: factorial))
126+
127+
func product(_ x: Tensor<Float>, count: Int) -> Tensor<Float> {
128+
precondition(count > 0)
129+
if count == 1 {
130+
return x
131+
}
132+
return x * product(x, count: count - 1)
133+
}
134+
expectEqual(Tensor(300),
135+
gradient(at: Tensor(10), in: { x in product(x, count: 3) }))
136+
expectEqual(Tensor(-20),
137+
gradient(at: Tensor(-10), in: { x in product(x, count: 2) }))
138+
expectEqual(Tensor(1),
139+
gradient(at: Tensor(100), in: { x in product(x, count: 1) }))
140+
}
141+
142+
runAllTests()

test/lit.cfg

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,6 +1419,14 @@ if not getattr(config, 'target_run_simple_swift', None):
14191419
'%s %%t/a.out' % (config.target_build_swift,
14201420
mcp_opt, config.target_codesign,
14211421
config.target_run)))
1422+
# SWIFT_ENABLE_TENSORFLOW
1423+
# TODO: Remove when differentiation control flow support is robust.
1424+
config.target_run_simple_swift_control_flow_differentiation = (
1425+
'%%empty-directory(%%t) && '
1426+
'%s %s %%s -Xllvm -differentiation-enable-control-flow -o %%t/a.out %s -module-name main && '
1427+
'%s %%t/a.out &&'
1428+
'%s %%t/a.out'
1429+
% (config.target_build_swift, mcp_opt, swift_tensorflow_extra_options, config.target_codesign, config.target_run))
14221430
config.target_run_simple_swift = (
14231431
'%%empty-directory(%%t) && '
14241432
'%s %s %%s -o %%t/a.out %s -module-name main && '
@@ -1476,6 +1484,9 @@ config.substitutions.append(('%target-swift-frontend', config.target_swift_front
14761484

14771485

14781486
config.substitutions.append(('%target-run-simple-swiftgyb', config.target_run_simple_swiftgyb))
1487+
# SWIFT_ENABLE_TENSORFLOW
1488+
# TODO: Remove when differentiation control flow support is robust.
1489+
config.substitutions.append(('%target-run-simple-swift-control-flow-differentiation', config.target_run_simple_swift_control_flow_differentiation))
14791490
config.substitutions.append(('%target-run-simple-swift\(([^)]+)\)', config.target_run_simple_swift_parameterized))
14801491
config.substitutions.append(('%target-run-simple-swift', config.target_run_simple_swift))
14811492
config.substitutions.append(('%target-run-stdlib-swiftgyb', config.target_run_stdlib_swiftgyb))

0 commit comments

Comments
 (0)