Skip to content

Commit 8c646af

Browse files
authored
[AutoDiff] Adjoint buffer optimization for address projections. (#25268)
- Do not allocate adjoint buffers for address projections; they become projections into their adjoint base buffer. - Use deterministic iteration order when processing `@differentiable` attributes in differentiation transform. - Add adjoint SIL tests.
1 parent 8d9df6c commit 8c646af

File tree

3 files changed

+141
-18
lines changed

3 files changed

+141
-18
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ using namespace swift;
5454
using llvm::DenseMap;
5555
using llvm::SmallDenseMap;
5656
using llvm::SmallDenseSet;
57+
using llvm::SmallMapVector;
5758
using llvm::SmallSet;
5859

5960
/// This flag is used to disable `autodiff_function_extract` instruction folding
@@ -844,7 +845,9 @@ class ADContext {
844845
SmallPtrSet<AutoDiffFunctionInst *, 32> processedAutoDiffFunctionInsts;
845846

846847
/// Mapping from `[differentiable]` attributes to invokers.
847-
DenseMap<SILDifferentiableAttr *, DifferentiationInvoker> invokers;
848+
/// `SmallMapVector` is used for deterministic insertion order iteration.
849+
SmallMapVector<SILDifferentiableAttr *, DifferentiationInvoker, 32>
850+
invokers;
848851

849852
/// Mapping from `autodiff_function` instructions to result indices.
850853
DenseMap<AutoDiffFunctionInst *, unsigned> resultIndices;
@@ -902,7 +905,8 @@ class ADContext {
902905
return processedAutoDiffFunctionInsts;
903906
}
904907

905-
DenseMap<SILDifferentiableAttr *, DifferentiationInvoker> &getInvokers() {
908+
llvm::SmallMapVector<SILDifferentiableAttr *, DifferentiationInvoker, 32> &
909+
getInvokers() {
906910
return invokers;
907911
}
908912

@@ -957,8 +961,8 @@ class ADContext {
957961
}
958962

959963
void cleanUp() {
960-
for (auto invokerInfo : invokers) {
961-
auto *attr = invokerInfo.getFirst();
964+
for (auto invokerPair : invokers) {
965+
auto *attr = std::get<0>(invokerPair);
962966
auto *original = attr->getOriginal();
963967
LLVM_DEBUG(getADDebugStream()
964968
<< "Removing [differentiable] attribute for "
@@ -4141,6 +4145,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41414145
auto addActiveValue = [&](SILValue v) {
41424146
if (visited.count(v))
41434147
return;
4148+
// Skip address projections.
4149+
// Address projections do not need their own adjoint buffers; they
4150+
// become projections into their adjoint base buffer.
4151+
if (Projection::isAddressProjection(v))
4152+
return;
41444153
visited.insert(v);
41454154
bbActiveValues.push_back(v);
41464155
};
@@ -6352,17 +6361,14 @@ void Differentiation::run() {
63526361
// A global differentiation context.
63536362
ADContext context(*this);
63546363

6355-
// Handle all the instructions and attributes in the module that trigger
6356-
// differentiation.
6364+
// Register all `@differentiable` attributes and `autodiff_function`
6365+
// instructions in the module that trigger differentiation.
63576366
for (SILFunction &f : module) {
6358-
// If `f` has a `[differentiable]` attribute, register `f` and the attribute
6359-
// with an invoker.
63606367
for (auto *diffAttr : f.getDifferentiableAttrs()) {
63616368
DifferentiationInvoker invoker(diffAttr);
6362-
auto insertion =
6363-
context.getInvokers().try_emplace(diffAttr, invoker);
6364-
assert(insertion.second &&
6369+
assert(!context.getInvokers().count(diffAttr) &&
63656370
"[differentiable] attribute already has an invoker");
6371+
context.getInvokers().insert({diffAttr, invoker});
63666372
continue;
63676373
}
63686374
for (SILBasicBlock &bb : f)
@@ -6387,10 +6393,10 @@ void Differentiation::run() {
63876393
bool errorOccurred = false;
63886394

63896395
// Process all `[differentiable]` attributes.
6390-
for (auto invokerInfo : context.getInvokers()) {
6391-
auto *attr = invokerInfo.first;
6396+
for (auto invokerPair : context.getInvokers()) {
6397+
auto *attr = invokerPair.first;
63926398
auto *original = attr->getOriginal();
6393-
auto invoker = invokerInfo.second;
6399+
auto invoker = invokerPair.second;
63946400
errorOccurred |=
63956401
context.processDifferentiableAttribute(original, attr, invoker);
63966402
}

test/AutoDiff/control_flow.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ ControlFlowTests.test("Conditionals") {
6868
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_tuple))
6969
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_tuple))
7070

71+
func cond_tuple2(_ x: Float) -> Float {
72+
// Convoluted function returning `x + x`.
73+
let y: (Float, Float) = (x, x)
74+
let y0 = y.0
75+
if x > 0 {
76+
let y1 = y.1
77+
return y0 + y1
78+
}
79+
let y0_double = y0 + y.0
80+
let y1 = y.1
81+
return y0_double - y1 + y.0
82+
}
83+
expectEqual((8, 2), valueWithGradient(at: 4, in: cond_tuple2))
84+
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_tuple2))
85+
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_tuple2))
86+
7187
func cond_tuple_var(_ x: Float) -> Float {
7288
// Convoluted function returning `x + x`.
7389
var y: (Float, Float) = (x, x)
@@ -135,6 +151,22 @@ ControlFlowTests.test("Conditionals") {
135151
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_struct))
136152
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_struct))
137153

154+
func cond_struct2(_ x: Float) -> Float {
155+
// Convoluted function returning `x + x`.
156+
let y = FloatPair(x, x)
157+
let y0 = y.first
158+
if x > 0 {
159+
let y1 = y.second
160+
return y0 + y1
161+
}
162+
let y0_double = y0 + y.first
163+
let y1 = y.second
164+
return y0_double - y1 + y.first
165+
}
166+
expectEqual((8, 2), valueWithGradient(at: 4, in: cond_struct2))
167+
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_struct2))
168+
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_struct2))
169+
138170
func cond_struct_var(_ x: Float) -> Float {
139171
// Convoluted function returning `x + x`.
140172
var y = FloatPair(x, x)

test/AutoDiff/control_flow_sil.swift

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
2-
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
2+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
33

4-
// TODO: Add adjoint SIL FileCheck tests.
4+
// TODO: Add FileCheck tests.
55

6-
// Test conditional: a simple if-diamond.
6+
//===----------------------------------------------------------------------===//
7+
// Conditionals
8+
//===----------------------------------------------------------------------===//
79

810
@differentiable
911
@_silgen_name("cond")
@@ -40,7 +42,7 @@ func cond(_ x: Float) -> Float {
4042
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
4143
// CHECK-DATA-STRUCTURES: }
4244

43-
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0
45+
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
4446
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
4547
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 ()
4648
// CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
@@ -64,6 +66,40 @@ func cond(_ x: Float) -> Float {
6466
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
6567
// CHECK-SIL: return [[VJP_RESULT]]
6668

69+
// CHECK-SIL-LABEL: sil hidden @AD__cond__adjoint_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__cond_bb3__PB__src_0_wrt_0) -> Float {
70+
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_bb3__PB__src_0_wrt_0):
71+
// CHECK-SIL: [[BB3_PRED:%.*]] = struct_extract %1 : $_AD__cond_bb3__PB__src_0_wrt_0, #_AD__cond_bb3__PB__src_0_wrt_0.predecessor
72+
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
73+
74+
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0):
75+
// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
76+
77+
// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
78+
// CHECK-SIL: [[BB1_PB:%.*]] = struct_extract [[BB1_PB_STRUCT]]
79+
// CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
80+
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
81+
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
82+
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)
83+
84+
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0):
85+
// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
86+
87+
// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
88+
// CHECK-SIL: [[BB2_PB:%.*]] = struct_extract [[BB2_PB_STRUCT]]
89+
// CHECK-SIL: [[BB2_ADJVALS:%.*]] = apply [[BB2_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
90+
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
91+
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
92+
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0)
93+
94+
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
95+
// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
96+
97+
// CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
98+
// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
99+
100+
// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float):
101+
// CHECK-SIL: return {{%.*}} : $Float
102+
67103
@differentiable
68104
@_silgen_name("nested_cond")
69105
func nested_cond(_ x: Float, _ y: Float) -> Float {
@@ -89,3 +125,52 @@ func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) ->
89125
}
90126
return y
91127
}
128+
129+
// Test control flow + tuple buffer.
130+
// Verify that adjoint buffers are not allocated for address projections.
131+
132+
@differentiable
133+
@_silgen_name("cond_tuple_var")
134+
func cond_tuple_var(_ x: Float) -> Float {
135+
// expected-warning @+1 {{variable 'y' was never mutated; consider changing to 'let' constant}}
136+
var y = (x, x)
137+
if x > 0 {
138+
return y.0
139+
}
140+
return y.1
141+
}
142+
// CHECK-SIL-LABEL: sil hidden @AD__cond_tuple_var__adjoint_src_0_wrt_0 : $@convention(thin) (Float, @guaranteed _AD__cond_tuple_var_bb3__PB__src_0_wrt_0) -> Float {
143+
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0):
144+
// CHECK-SIL: [[BB3_PRED:%.*]] = struct_extract %1 : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0, #_AD__cond_tuple_var_bb3__PB__src_0_wrt_0.predecessor
145+
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
146+
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
147+
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
148+
149+
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
150+
// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
151+
152+
// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
153+
// CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]]
154+
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
155+
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
156+
// CHECK-SIL: [[BB1_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB1_PRED]]
157+
// CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
158+
159+
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
160+
// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float)
161+
162+
// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float):
163+
// CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]]
164+
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
165+
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
166+
// CHECK-SIL: [[BB2_PB_STRUCT_DATA:%.*]] = unchecked_enum_data [[BB2_PRED]]
167+
// CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
168+
169+
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
170+
// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
171+
172+
// CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
173+
// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float)
174+
175+
// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float):
176+
// CHECK-SIL: return {{%.*}} : $Float

0 commit comments

Comments
 (0)