Skip to content

Commit 9b36197

Browse files
committed
Add control flow + address projection adjoint tests.
- Use deterministic iteration order when processing `@differentiable` attributes in differentiation transform. - Add adjoint SIL tests.
1 parent 488dbff commit 9b36197

File tree

2 files changed

+100
-11
lines changed

2 files changed

+100
-11
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6357,17 +6357,21 @@ void Differentiation::run() {
63576357
// A global differentiation context.
63586358
ADContext context(*this);
63596359

6360-
// Handle all the instructions and attributes in the module that trigger
6361-
// differentiation.
6360+
// Use a temporary list to store `@differentiable` attributes and invokers
6361+
// for deterministic iteration order.
6362+
SmallVector<std::pair<SILDifferentiableAttr *, DifferentiationInvoker>, 32>
6363+
invokerList;
6364+
6365+
// Register all `@differentiable` attributes and `autodiff_function`
6366+
// instructions in the module that trigger differentiation.
63626367
for (SILFunction &f : module) {
6363-
// If `f` has a `[differentiable]` attribute, register `f` and the attribute
6364-
// with an invoker.
63656368
for (auto *diffAttr : f.getDifferentiableAttrs()) {
63666369
DifferentiationInvoker invoker(diffAttr);
63676370
auto insertion =
63686371
context.getInvokers().try_emplace(diffAttr, invoker);
63696372
assert(insertion.second &&
63706373
"[differentiable] attribute already has an invoker");
6374+
invokerList.push_back({diffAttr, invoker});
63716375
continue;
63726376
}
63736377
for (SILBasicBlock &bb : f)
@@ -6392,10 +6396,10 @@ void Differentiation::run() {
63926396
bool errorOccurred = false;
63936397

63946398
// Process all `[differentiable]` attributes.
6395-
for (auto invokerInfo : context.getInvokers()) {
6396-
auto *attr = invokerInfo.first;
6399+
for (auto invokerPair : invokerList) {
6400+
auto *attr = invokerPair.first;
63976401
auto *original = attr->getOriginal();
6398-
auto invoker = invokerInfo.second;
6402+
auto invoker = invokerPair.second;
63996403
errorOccurred |=
64006404
context.processDifferentiableAttribute(original, attr, invoker);
64016405
}

test/AutoDiff/control_flow_sil.swift

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
2-
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
2+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
33

4-
// TODO: Add adjoint SIL FileCheck tests.
4+
// TODO: Add FileCheck tests.
55

6-
// Test conditional: a simple if-diamond.
6+
//===----------------------------------------------------------------------===//
7+
// Conditionals
8+
//===----------------------------------------------------------------------===//
79

810
@differentiable
911
@_silgen_name("cond")
@@ -40,7 +42,7 @@ func cond(_ x: Float) -> Float {
4042
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
4143
// CHECK-DATA-STRUCTURES: }
4244

43-
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0
45+
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
4446
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
4547
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 ()
4648
// CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
@@ -64,6 +66,40 @@ func cond(_ x: Float) -> Float {
6466
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
6567
// CHECK-SIL: return [[VJP_RESULT]]
6668

69+
// CHECK-SIL-LABEL: sil hidden @AD__cond__adjoint_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__cond_bb3__PB__src_0_wrt_0) -> Float {
70+
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_bb3__PB__src_0_wrt_0):
71+
// CHECK-SIL: [[BB3_PRED:%.*]] = struct_extract %1 : $_AD__cond_bb3__PB__src_0_wrt_0, #_AD__cond_bb3__PB__src_0_wrt_0.predecessor
72+
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
73+
74+
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0):
75+
// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
76+
77+
// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
78+
// CHECK-SIL: [[BB1_PB:%.*]] = struct_extract [[BB1_PB_STRUCT]]
79+
// CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
80+
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
81+
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
82+
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)
83+
84+
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0):
85+
// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
86+
87+
// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
88+
// CHECK-SIL: [[BB2_PB:%.*]] = struct_extract [[BB2_PB_STRUCT]]
89+
// CHECK-SIL: [[BB2_ADJVALS:%.*]] = apply [[BB2_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
90+
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
91+
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
92+
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)
93+
94+
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
95+
// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
96+
97+
// CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
98+
// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
99+
100+
// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float):
101+
// CHECK-SIL: return {{%.*}} : $Float
102+
67103
@differentiable
68104
@_silgen_name("nested_cond")
69105
func nested_cond(_ x: Float, _ y: Float) -> Float {
@@ -89,3 +125,52 @@ func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) ->
89125
}
90126
return y
91127
}
128+
129+
// Test control flow + tuple buffer.
130+
// Verify that adjoint buffers are not allocated for address projections.
131+
132+
@differentiable
133+
@_silgen_name("cond_tuple_var")
134+
func cond_tuple_var(_ x: Float) -> Float {
135+
// expected-warning @+1 {{variable 'y' was never mutated; consider changing to 'let' constant}}
136+
var y = (x, x)
137+
if x > 0 {
138+
return y.0
139+
}
140+
return y.1
141+
}
142+
// CHECK-SIL-LABEL: sil hidden @AD__cond_tuple_var__adjoint_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__cond_tuple_var_bb3__PB__src_0_wrt_0) -> Float {
143+
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0):
144+
// CHECK-SIL: [[BB3_PRED:%.*]] = struct_extract %1 : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0, #_AD__cond_tuple_var_bb3__PB__src_0_wrt_0.predecessor
145+
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
146+
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
147+
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
148+
149+
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
150+
// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
151+
152+
// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
153+
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
154+
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
155+
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
156+
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
157+
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
158+
159+
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
160+
// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
161+
162+
// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
163+
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
164+
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
165+
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
166+
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
167+
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
168+
169+
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
170+
// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
171+
172+
// CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
173+
// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
174+
175+
// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float):
176+
// CHECK-SIL: return {{%.*}} : $Float

0 commit comments

Comments
 (0)