Skip to content

Commit 5dd9c50

Browse files
authored
[AutoDiff] Move Differentiation before OwnershipModelEliminator. (swiftlang#26157)
Moves Differentiation before OwnershipModelEliminator. Now Differentiation happens right after DefiniteInitialization. ## Changes ### Activity analysis * Propagate usefulness and variedness through `store_borrow` instructions. * Fix a bug where standard usefulness does not get propagated if the instruction goes through the `mayReadFromMemory()` case. Remove special handling `mayReadFromMemory()` and propagate usefulness through all buffer operands instead. * Handle `destructure_tuple` instead of `tuple_extract` in array literal initialization pattern matcher. ### `autodiff_function` canonicalization * Rewrite argument cloning logic as `copyParameterArgumentsForApply`, used in `reapplyFunctionConversions` and the curry thunk cloning logic in `ADContext::promoteToDifferentiableFunction`. ### VJPEmitter * Each `@guaranteed` trampoline argument needs to have a lifetime-ending use past its destination argument's lifetime-ending uses, so we keep track of these pairs of arguments in `trampolinedGuaranteedPhiArguments` and emit `end_borrow`s when function cloning is finished. * Create trampoline blocks for `cond_br` instructions to conform to ownership rules. ### PullbackEmitter * Make pullback struct arguments have `@owned` ownership, for both function arguments and phi arguments. * Make all other phi arguments also have `@owned` ownership. * Note: Linear maps get evaluated linearly, so all values should get consumed immediately when a pullback is called. This is not the case yet since pullback functions have `@guaranteed` arguments. We should change this in the future so that all calls to pullbacks consume their arguments. ([TF-761](https://bugs.swift.org/browse/TF-761)) * Emit a `switch_enum` even if there is only one successor. It no longer triggers any verification failure. • Remove the [TF-585](https://bugs.swift.org/browse/TF-585) workaround (emitting a `fix_lifetime` on boxed enums) since the crasher in AllocBoxToStack is no longer reproducible. * Handle `destructure_tuple` instead of `tuple_extract` in array literal initialization adjoint emission logic. * Add pullback emission visitors for `destructure_tuple`, `load_borrow`, `store_borrow`, `copy_value`, and `begin_borrow`. Resolves [TF-709](https://bugs.swift.org/browse/TF-709) and [TF-585](https://bugs.swift.org/browse/TF-585). Also improves some source locations in diagnostics since the differentiation transform runs on a higher-level IR.
1 parent 3374bc7 commit 5dd9c50

10 files changed

+614
-415
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 553 additions & 344 deletions
Large diffs are not rendered by default.

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ static void addMandatoryOptPipeline(SILPassPipelinePlan &P) {
9393
P.addAllocBoxToStack();
9494
P.addNoReturnFolding();
9595
addDefiniteInitialization(P);
96+
97+
// SWIFT_ENABLE_TENSORFLOW
98+
P.addDifferentiation();
9699
// Only run semantic arc opts if we are optimizing and if mandatory semantic
97100
// arc opts is explicitly enabled.
98101
//
@@ -107,8 +110,6 @@ static void addMandatoryOptPipeline(SILPassPipelinePlan &P) {
107110
}
108111
if (!Options.StripOwnershipAfterSerialization)
109112
P.addOwnershipModelEliminator();
110-
// SWIFT_ENABLE_TENSORFLOW
111-
P.addDifferentiation();
112113
P.addMandatoryInlining();
113114
P.addMandatorySILLinker();
114115

test/AutoDiff/anyderivative.swift

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ AnyDerivativeTests.test("Vector") {
2525
var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
2626
tan += tan
2727
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
28-
expectEqual(tan, tan.allDifferentiableVariables)
2928
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan + tan)
3029
expectEqual(AnyDerivative(Vector.TangentVector(x: 0, y: 0)), tan - tan)
3130
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan.moved(along: tan))
@@ -37,7 +36,6 @@ AnyDerivativeTests.test("Generic") {
3736
let cotan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
3837
tan += tan
3938
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 2)), tan)
40-
expectEqual(tan, tan.allDifferentiableVariables)
4139
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan + tan)
4240
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 0)), tan - tan)
4341
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan.moved(along: tan))
@@ -51,7 +49,6 @@ AnyDerivativeTests.test("Zero") {
5149
expectEqual(zero, zero + zero)
5250
expectEqual(zero, zero - zero)
5351
expectEqual(zero, zero.moved(along: zero))
54-
expectEqual(zero, zero.allDifferentiableVariables)
5552

5653
var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
5754
expectEqual(zero, zero)
@@ -125,11 +122,7 @@ AnyDerivativeTests.test("Derivatives") {
125122

126123
// Test `AnyDerivative` initializer.
127124
func typeErased<T>(_ x: T) -> AnyDerivative
128-
where T : Differentiable, T.TangentVector == T,
129-
T.AllDifferentiableVariables == T,
130-
// NOTE: The requirement below should be defined on `Differentiable`.
131-
// But it causes a crash due to generic signature minimization bug.
132-
T.TangentVector == T.TangentVector.AllDifferentiableVariables
125+
where T : Differentiable, T.TangentVector == T
133126
{
134127
let any = AnyDerivative(x)
135128
return any + any

test/AutoDiff/autodiff_function_silgen.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ func myfunction(_ f: @escaping @differentiable (Float) -> (Float)) -> (Float) ->
1414
return f
1515
}
1616

17+
var global_f: @differentiable (Float) -> Float = {$0}
18+
19+
func calls_global_f() {
20+
_ = global_f(10)
21+
}
22+
1723
func apply() {
1824
_ = myfunction(thin)
1925
}
@@ -46,7 +52,6 @@ func apply() {
4652
// CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = autodiff_function [wrt 0] [order 1] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float
4753

4854
// CHECK-SIL: [[DIFFED:%.*]] = autodiff_function [wrt 0] [order 1] {{%.*}} : $@callee_guaranteed (Float) -> Float
49-
// CHECK-SIL: release_value [[DIFFED]] : $@differentiable @callee_guaranteed (Float) -> Float
5055

5156
//===----------------------------------------------------------------------===//
5257
// Reabstraction

test/AutoDiff/control_flow_diagnostics.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ enum Tree : Differentiable & AdditiveArithmetic {
113113

114114
// expected-error @+1 {{function is not differentiable}}
115115
@differentiable
116-
// expected-note @+2 {{when differentiating this function definition}}
117-
// expected-note @+1 {{differentiating enum values is not yet supported}}
116+
// expected-note @+1 {{when differentiating this function definition}}
118117
static func +(_ lhs: Self, _ rhs: Self) -> Self {
119118
switch (lhs, rhs) {
119+
// expected-note @+1 {{differentiating enum values is not yet supported}}
120120
case let (.leaf(x), .leaf(y)):
121121
return .leaf(x + y)
122122
case let (.branch(x1, x2), .branch(y1, y2)):
@@ -128,10 +128,10 @@ enum Tree : Differentiable & AdditiveArithmetic {
128128

129129
// expected-error @+1 {{function is not differentiable}}
130130
@differentiable
131-
// expected-note @+2 {{when differentiating this function definition}}
132-
// expected-note @+1 {{differentiating enum values is not yet supported}}
131+
// expected-note @+1 {{when differentiating this function definition}}
133132
static func -(_ lhs: Self, _ rhs: Self) -> Self {
134133
switch (lhs, rhs) {
134+
// expected-note @+1 {{differentiating enum values is not yet supported}}
135135
case let (.leaf(x), .leaf(y)):
136136
return .leaf(x - y)
137137
case let (.branch(x1, x2), .branch(y1, y2)):

test/AutoDiff/control_flow_sil.swift

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -42,63 +42,58 @@ func cond(_ x: Float) -> Float {
4242
// CHECK-DATA-STRUCTURES: case bb1(_AD__cond_bb1__PB__src_0_wrt_0)
4343
// CHECK-DATA-STRUCTURES: }
4444

45-
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
45+
46+
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
4647
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
4748
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 ()
49+
// CHECK-SIL: cond_br {{%.*}}, bb1, bb3
50+
51+
// CHECK-SIL: bb1:
4852
// CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
49-
// CHECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
50-
// CHECK-SIL: cond_br {{%.*}}, bb1([[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0), bb2([[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_
53+
// CHECK-SIL: br bb2([[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0)
5154

52-
// CHECK-SIL: bb1([[BB1_PRED_ARG:%.*]] : $_AD__cond_bb1__Pred__src_0_wrt_0)
55+
// CHECK-SIL: bb2([[BB1_PRED_ARG:%.*]] : $_AD__cond_bb1__Pred__src_0_wrt_0)
5356
// CHECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__cond_bb1__PB__src_0_wrt_0
5457
// CHECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1, [[BB1_PB_STRUCT]]
55-
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
58+
// CHECK-SIL: br bb5({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
59+
60+
// CHECK-SIL: bb3:
61+
// CHECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
62+
// CHECK-SIL: br bb4([[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_0)
5663

57-
// CHECK-SIL: bb2([[BB2_PRED_ARG:%.*]] : $_AD__cond_bb2__Pred__src_0_wrt_0)
64+
// CHECK-SIL: bb4([[BB2_PRED_ARG:%.*]] : $_AD__cond_bb2__Pred__src_0_wrt_0)
5865
// CHECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__cond_bb2__PB__src_0_wrt_0
5966
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1, [[BB2_PB_STRUCT]]
60-
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
67+
// CHECK-SIL: br bb5({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
6168

62-
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
69+
// CHECK-SIL: bb5([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0)
6370
// CHECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__cond_bb3__PB__src_0_wrt_0
6471
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @AD__cond__pullback_src_0_wrt_0
6572
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]])
6673
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
6774
// CHECK-SIL: return [[VJP_RESULT]]
6875

69-
// CHECK-SIL-LABEL: sil hidden @AD__cond__pullback_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__cond_bb3__PB__src_0_wrt_0) -> Float {
70-
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_bb3__PB__src_0_wrt_0):
71-
// CHECK-SIL: [[BB3_PRED:%.*]] = struct_extract %1 : $_AD__cond_bb3__PB__src_0_wrt_0, #_AD__cond_bb3__PB__src_0_wrt_0.predecessor
76+
77+
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__cond__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__cond_bb3__PB__src_0_wrt_0) -> Float {
78+
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : @owned $_AD__cond_bb3__PB__src_0_wrt_0):
79+
// CHECK-SIL: [[BB3_PRED:%.*]] = destructure_struct [[BB3_PB_STRUCT]] : $_AD__cond_bb3__PB__src_0_wrt_0
7280
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
7381

74-
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0):
82+
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : @owned $_AD__cond_bb1__PB__src_0_wrt_0):
7583
// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0)
7684

77-
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0):
78-
// CHECK-SIL: [[BB1_PB:%.*]] = struct_extract [[BB1_PB_STRUCT]]
85+
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : @owned $_AD__cond_bb1__PB__src_0_wrt_0):
86+
// CHECK-SIL: ([[BB1_PRED:%.*]], [[BB1_PB:%.*]]) = destructure_struct [[BB1_PB_STRUCT]]
7987
// CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
80-
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
81-
// CHECK-SIL: release_value {{%.*}} : $Float
82-
// CHECK-SIL: release_value {{%.*}} : $Float
83-
// CHECK-SIL: release_value {{%.*}} : $Float
84-
// CHECK-SIL: release_value {{%.*}} : $Float
85-
// CHECK-SIL: release_value {{%.*}} : $Float
86-
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
87-
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)
88-
89-
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0):
88+
// CHECK-SIL: switch_enum [[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0, case #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1: bb5
89+
90+
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $_AD__cond_bb2__PB__src_0_wrt_0):
9091
// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0)
9192

92-
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0):
93-
// CHECK-SIL: [[BB2_PB:%.*]] = struct_extract [[BB2_PB_STRUCT]]
93+
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : @owned $_AD__cond_bb2__PB__src_0_wrt_0):
94+
// CHECK-SIL: ([[BB2_PRED:%.*]], [[BB2_PB:%.*]]) = destructure_struct [[BB2_PB_STRUCT]]
9495
// CHECK-SIL: [[BB2_ADJVALS:%.*]] = apply [[BB2_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
95-
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
96-
// CHECK-SIL: release_value {{%.*}} : $Float
97-
// CHECK-SIL: release_value {{%.*}} : $Float
98-
// CHECK-SIL: release_value {{%.*}} : $Float
99-
// CHECK-SIL: release_value {{%.*}} : $Float
100-
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
101-
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)
96+
// CHECK-SIL: switch_enum [[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_0, case #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1: bb6
10297

10398
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
10499
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0)
@@ -107,8 +102,6 @@ func cond(_ x: Float) -> Float {
107102
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0)
108103

109104
// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
110-
// CHECK-SIL: release_value {{%.*}} : $Float
111-
// CHECK-SIL: release_value {{%.*}} : $Float
112105
// CHECK-SIL: return {{%.*}} : $Float
113106

114107
@differentiable
@@ -164,9 +157,10 @@ func cond_tuple_var(_ x: Float) -> Float {
164157
}
165158
return y.1
166159
}
167-
// CHECK-SIL-LABEL: sil hidden @AD__cond_tuple_var__pullback_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__cond_tuple_var_bb3__PB__src_0_wrt_0) -> Float {
160+
161+
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__cond_tuple_var__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__cond_tuple_var_bb3__PB__src_0_wrt_0) -> Float {
168162
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0):
169-
// CHECK-SIL: [[BB3_PRED:%.*]] = struct_extract %1 : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0, #_AD__cond_tuple_var_bb3__PB__src_0_wrt_0.predecessor
163+
// CHECK-SIL: [[BB3_PRED:%.*]] = destructure_struct [[BB3_PB_STRUCT]] : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0
170164
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
171165
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
172166
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
@@ -175,21 +169,19 @@ func cond_tuple_var(_ x: Float) -> Float {
175169
// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0)
176170

177171
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
178-
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
172+
// CHECK-SIL: [[BB1_PRED:%.*]] = destructure_struct [[BB1_PB_STRUCT]]
179173
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
180174
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
181-
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
182-
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
175+
// CHECK-SIL: switch_enum [[BB1_PRED]] : $_AD__cond_tuple_var_bb1__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb1__Pred__src_0_wrt_0.bb0!enumelt.1: bb5 // id: %81
183176

184177
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
185178
// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0)
186179

187180
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
188-
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
181+
// CHECK-SIL: [[BB2_PRED:%.*]] = destructure_struct [[BB2_PB_STRUCT]]
189182
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
190183
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
191-
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
192-
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
184+
// CHECK-SIL: switch_enum [[BB2_PRED]] : $_AD__cond_tuple_var_bb2__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb2__Pred__src_0_wrt_0.bb0!enumelt.1: bb6
193185

194186
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
195187
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)

test/AutoDiff/refcounting.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,16 @@ _ = pullback(at: Vector.zero, in: testOwnedVector)
4848
// CHECK: bb0([[SEED:%.*]] : $Vector, [[PB_STRUCT:%.*]] : ${{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__PB__src_0_wrt_0_1):
4949
// CHECK: [[PB:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}UsesMethodOfNoDerivativeMember{{.*}}applied2to{{.*}}__PB__src_0_wrt_0_1
5050
// CHECK: [[NEEDED_COTAN:%.*]] = apply [[PB]]([[SEED]]) : $@callee_guaranteed (@guaranteed Vector) -> @owned Vector
51-
// CHECK: release_value [[SEED:%.*]] : $Vector
5251

53-
// CHECK-LABEL sil hidden @{{.*}}subset_pullback_releases_unused_ones{{.*}}__pullback_src_0_wrt_0
52+
// CHECK-LABEL: sil hidden @{{.*}}subset_pullback_releases_unused_ones{{.*}}__pullback_src_0_wrt_0
5453
// CHECK: bb0([[SEED:%.*]] : $Vector, [[PB_STRUCT:%.*]] : ${{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0):
54+
// CHECK: [[PB1:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0, #{{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0.pullback_0
5555
// CHECK: [[PB0:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}subset_pullback_releases_unused_ones{{.*}}, #{{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0.pullback_1
5656
// CHECK: [[NEEDED_COTAN0:%.*]] = apply [[PB0]]([[SEED]]) : $@callee_guaranteed (@guaranteed Vector) -> @owned Vector
57+
// CHECK: strong_release [[PB0]]
5758
// CHECK-NOT: release_value [[NEEDED_COTAN0]] : $Vector
58-
// CHECK: [[PB1:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0, #{{.*}}subset_pullback_releases_unused_ones{{.*}}__PB__src_0_wrt_0.pullback_0
5959
// CHECK: [[NEEDED_COTAN1:%.*]] = apply [[PB1]]([[NEEDED_COTAN0]]) : $@callee_guaranteed (@guaranteed Vector) -> @owned Vector
60+
// CHECK: strong_release [[PB1]]
6061
// CHECK: retain_value [[NEEDED_COTAN1]] : $Vector
6162
// CHECK: release_value [[NEEDED_COTAN0]] : $Vector
6263
// CHECK: release_value [[NEEDED_COTAN1]] : $Vector
@@ -96,8 +97,8 @@ _ = pullback(at: Vector.zero, in: testOwnedVector)
9697
// CHECK-LABEL: @{{.*}}testOwnedVector{{.*}}__pullback_src_0_wrt_0
9798
// CHECK: bb0({{%.*}} : $Vector, [[PB_STRUCT:%.*]] : ${{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0):
9899
// CHECK: [[PULLBACK0:%.*]] = struct_extract [[PB_STRUCT]] : ${{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0, #{{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0.pullback_0
99-
// CHECK-NOT: release_value [[PULLBACK0]]
100-
// CHECK-NOT: release_value [[PB_STRUCT]]
100+
// CHECK-NOT: release_value [[PULLBACK0]] : @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector)
101+
// CHECK-NOT: release_value [[PB_STRUCT]] : ${{.*}}testOwnedVector{{.*}}__PB__src_0_wrt_0
101102
// CHECK: }
102103

103104
func side_effect_release_zero(_ x: Vector) -> Vector {

test/AutoDiff/simple_math.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ SimpleMathTests.test("GlobalDiffableFunc") {
7979
expectEqual(2, gradient(at: 1, in: foo_diffable))
8080
expectEqual(2, gradient(at: 1, in: { x in foo_diffable(x) }))
8181
expectEqual(1, gradient(at: 1, in: { (x: Float) -> Float in
82-
foo_diffable = { x in x + 1 };
82+
foo_diffable = { x in x + 1 }
8383
return foo_diffable(x)
8484
}))
8585
expectEqual(1, gradient(at: 1, in: foo_diffable))

0 commit comments

Comments
 (0)