Skip to content

Commit 0f884b8

Browse files
authored
[AutoDiff] [SR-14218] Correctly propagate tangent vectors of inout parameters from functions with multiple basic blocks (#37861)
* Attempt at a patch for SR-14053 and SR-14218, based on Dan Zheng's initial fix.
* Incorporated Dan's cleanup suggestions.
* Converting a bool SmallVector to a SmallBitVector.
* Testing if Windows issues with this test are due to runtime support.
* Simplifying test case.
1 parent e547c06 commit 0f884b8

File tree

2 files changed

+132
-5
lines changed

2 files changed

+132
-5
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,8 @@ bool PullbackCloner::Implementation::run() {
20522052
SmallVector<SILValue, 8> retElts;
20532053
// This vector will contain all indirect parameter adjoint buffers.
20542054
SmallVector<SILValue, 4> indParamAdjoints;
2055+
// This vector will identify the locations where initialization is needed.
2056+
SmallBitVector outputsToInitialize;
20552057

20562058
auto conv = getOriginal().getConventions();
20572059
auto origParams = getOriginal().getArgumentsWithoutIndirectResults();
@@ -2071,25 +2073,62 @@ bool PullbackCloner::Implementation::run() {
20712073
case SILValueCategory::Address: {
20722074
auto adjBuf = getAdjointBuffer(origEntry, origParam);
20732075
indParamAdjoints.push_back(adjBuf);
2076+
outputsToInitialize.push_back(
2077+
!conv.getParameters()[parameterIndex].isIndirectMutating());
20742078
break;
20752079
}
20762080
}
20772081
};
2082+
SmallVector<SILArgument *, 4> pullbackIndirectResults(
2083+
getPullback().getIndirectResults().begin(),
2084+
getPullback().getIndirectResults().end());
2085+
20782086
// Collect differentiation parameter adjoints.
2087+
// Do a first pass to collect non-inout values.
2088+
unsigned pullbackInoutArgumentIndex = 0;
2089+
for (auto i : getConfig().parameterIndices->getIndices()) {
2090+
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
2091+
if (!isParameterInout) {
2092+
addRetElt(i);
2093+
}
2094+
}
2095+
2096+
// Do a second pass for all inout parameters.
20792097
for (auto i : getConfig().parameterIndices->getIndices()) {
2080-
// Skip `inout` parameters.
2081-
if (conv.getParameters()[i].isIndirectMutating())
2098+
// Skip non-inout parameters.
2099+
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
2100+
if (!isParameterInout)
20822101
continue;
2102+
2103+
// Skip `inout` parameters for functions with a single basic block:
2104+
// adjoint accumulation for those parameters is already done by
2105+
// per-instruction visitors.
2106+
if (getOriginal().size() == 1)
2107+
continue;
2108+
2109+
// For functions with multiple basic blocks, accumulation is needed
2110+
// for `inout` parameters because pullback basic blocks have different
2111+
// adjoint buffers.
2112+
auto pullbackInoutArgument =
2113+
getPullback()
2114+
.getArgumentsWithoutIndirectResults()[pullbackInoutArgumentIndex++];
2115+
pullbackIndirectResults.push_back(pullbackInoutArgument);
20832116
addRetElt(i);
20842117
}
20852118

20862119
// Copy them to adjoint indirect results.
2087-
assert(indParamAdjoints.size() == getPullback().getIndirectResults().size() &&
2120+
assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
20882121
"Indirect parameter adjoint count mismatch");
2089-
for (auto pair : zip(indParamAdjoints, getPullback().getIndirectResults())) {
2122+
unsigned currentIndex = 0;
2123+
for (auto pair : zip(indParamAdjoints, pullbackIndirectResults)) {
20902124
auto source = std::get<0>(pair);
20912125
auto *dest = std::get<1>(pair);
2092-
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
2126+
if (outputsToInitialize[currentIndex]) {
2127+
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
2128+
} else {
2129+
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization);
2130+
}
2131+
currentIndex++;
20932132
// Prevent source buffer from being deallocated, since the underlying
20942133
// value is moved.
20952134
destroyedLocalAllocations.insert(source);
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
import _Differentiation
6+
7+
// Suite on which the three `.test(...)` cases in this file are registered.
var InoutControlFlowTests = TestSuite("InoutControlFlow")
8+
9+
// SR-14218
10+
/// Two-field `Differentiable` fixture shared by every test in this file.
/// Its mutating methods are what the `loss` function differentiates through.
struct Model: Differentiable {
11+
var first: Float = 3
12+
var second: Float = 1
13+
14+
// Thin wrapper whose only statement is the call to `inner()`.
mutating func outer() {
15+
inner()
16+
}
17+
18+
// Copies `first` into `second`, then executes an empty `if`.
mutating func inner() {
19+
self.second = self.first
20+
21+
// Dummy no-op if block, required to introduce control flow.
22+
let x = 5
23+
if x < 50 {}
24+
}
25+
}
26+
27+
/// Returns `second` of a local copy of `model` after calling the mutating
/// `outer()` on that copy.
@differentiable(reverse)
28+
func loss(model: Model) -> Float{
29+
// Shadow the parameter with a mutable copy so the mutating method can be called.
var model = model
30+
model.outer()
31+
return model.second
32+
}
33+
34+
// `loss` returns `second` after `inner()` has assigned `second = first`, so the
// result depends only on the input's `first`: the expected gradient is 1 for
// `first` and 0 for `second`.
InoutControlFlowTests.test("MutatingBeforeControlFlow") {
35+
// `model` is only read here, so it is bound with `let`.
let model = Model()
36+
let grad = gradient(at: model, of: loss)
37+
expectEqual(1, grad.first)
38+
expectEqual(0, grad.second)
39+
}
40+
41+
// SR-14053
42+
/// Overwrites `model.first` with `model.second * multiplier`, then executes an
/// empty `if`. The `inout` parameter comes first in this variant.
@differentiable(reverse)
43+
func adjust(model: inout Model, multiplier: Float) {
44+
model.first = model.second * multiplier
45+
46+
// Dummy no-op if block, required to introduce control flow.
47+
let x = 5
48+
if x < 50 {}
49+
}
50+
51+
/// Applies `adjust(model:multiplier:)` to a local copy of `model` and returns
/// the copy's `first`.
@differentiable(reverse)
52+
func loss2(model: Model, multiplier: Float) -> Float {
53+
// Mutable copy, needed to pass the value as `inout`.
var model = model
54+
adjust(model: &model, multiplier: multiplier)
55+
return model.first
56+
}
57+
58+
// `loss2` evaluates to `second * multiplier`, so with `multiplier == 5` the
// expected model gradient is 0 for `first` and 5 for `second`.
InoutControlFlowTests.test("InoutParameterWithControlFlow") {
59+
// `model` is only read here, so it is bound with `let`.
let model = Model(first: 1, second: 3)
60+
let grad = gradient(at: model, 5.0, of: loss2)
61+
expectEqual(0, grad.0.first)
62+
expectEqual(5, grad.0.second)
63+
}
64+
65+
/// Same body as `adjust(model:multiplier:)`, but the `inout` parameter is
/// declared after `multiplier` instead of before it.
@differentiable(reverse)
66+
func adjust2(multiplier: Float, model: inout Model) {
67+
model.first = model.second * multiplier
68+
69+
// Dummy no-op if block, required to introduce control flow.
70+
let x = 5
71+
if x < 50 {}
72+
}
73+
74+
/// Applies `adjust2(multiplier:model:)` to a local copy of `model` and returns
/// the copy's `first`.
@differentiable(reverse)
75+
func loss3(model: Model, multiplier: Float) -> Float {
76+
// Mutable copy, needed to pass the value as `inout`.
var model = model
77+
adjust2(multiplier: multiplier, model: &model)
78+
return model.first
79+
}
80+
81+
// `loss3` evaluates to `second * multiplier`, so with `multiplier == 5` the
// expected model gradient is 0 for `first` and 5 for `second`.
InoutControlFlowTests.test("LaterInoutParameterWithControlFlow") {
82+
// `model` is only read here, so it is bound with `let`.
let model = Model(first: 1, second: 3)
83+
let grad = gradient(at: model, 5.0, of: loss3)
84+
expectEqual(0, grad.0.first)
85+
expectEqual(5, grad.0.second)
86+
}
87+
88+
// Final top-level statement of the test file.
// NOTE(review): presumably hands control to the StdlibUnittest runner, which
// executes the suite registered above — confirm.
runAllTests()

0 commit comments

Comments
 (0)