Skip to content

Commit 488dbff

Browse files
committed
[AutoDiff] Adjoint buffer optimization for address projections.
Do not allocate adjoint buffers for address projections; they become projections into their adjoint base buffer.
1 parent 2d9111d commit 488dbff

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4141,6 +4141,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41414141
auto addActiveValue = [&](SILValue v) {
41424142
if (visited.count(v))
41434143
return;
4144+
// Skip address projections.
4145+
// Address projections do not need their own adjoint buffers; they
4146+
// become projections into their adjoint base buffer.
4147+
if (Projection::isAddressProjection(v))
4148+
return;
41444149
visited.insert(v);
41454150
bbActiveValues.push_back(v);
41464151
};

test/AutoDiff/control_flow.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ ControlFlowTests.test("Conditionals") {
6868
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_tuple))
6969
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_tuple))
7070

71+
func cond_tuple2(_ x: Float) -> Float {
72+
// Convoluted function returning `x + x`.
73+
let y: (Float, Float) = (x, x)
74+
let y0 = y.0
75+
if x > 0 {
76+
let y1 = y.1
77+
return y0 + y1
78+
}
79+
let y0_double = y0 + y.0
80+
let y1 = y.1
81+
return y0_double - y1 + y.0
82+
}
83+
expectEqual((8, 2), valueWithGradient(at: 4, in: cond_tuple2))
84+
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_tuple2))
85+
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_tuple2))
86+
7187
func cond_tuple_var(_ x: Float) -> Float {
7288
// Convoluted function returning `x + x`.
7389
var y: (Float, Float) = (x, x)
@@ -135,6 +151,22 @@ ControlFlowTests.test("Conditionals") {
135151
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_struct))
136152
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_struct))
137153

154+
func cond_struct2(_ x: Float) -> Float {
155+
// Convoluted function returning `x + x`.
156+
let y = FloatPair(x, x)
157+
let y0 = y.first
158+
if x > 0 {
159+
let y1 = y.second
160+
return y0 + y1
161+
}
162+
let y0_double = y0 + y.first
163+
let y1 = y.second
164+
return y0_double - y1 + y.first
165+
}
166+
expectEqual((8, 2), valueWithGradient(at: 4, in: cond_struct2))
167+
expectEqual((-20, 2), valueWithGradient(at: -10, in: cond_struct2))
168+
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: cond_struct2))
169+
138170
func cond_struct_var(_ x: Float) -> Float {
139171
// Convoluted function returning `x + x`.
140172
var y = FloatPair(x, x)

0 commit comments

Comments
 (0)