Skip to content

Commit 4b7baf0

Browse files
authored
[AutoDiff] Properly collect inout parameter adjoints (#41559)
Apparently, the parameter index calculation in #37861 was not always correct in the presence of other pullback parameters (e.g. captures and non-differentiated args). Collect all inout parameters and collect inout parameter adjoints correctly. Resolves SR-15891.
1 parent c9e3699 commit 4b7baf0

File tree

2 files changed

+69
-26
lines changed

2 files changed

+69
-26
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,37 +2085,38 @@ bool PullbackCloner::Implementation::run() {
20852085

20862086
// Collect differentiation parameter adjoints.
20872087
// Do a first pass to collect non-inout values.
2088-
unsigned pullbackInoutArgumentIndex = 0;
20892088
for (auto i : getConfig().parameterIndices->getIndices()) {
2090-
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
2091-
if (!isParameterInout) {
2089+
if (!conv.getParameters()[i].isIndirectMutating()) {
2090+
addRetElt(i);
2091+
}
2092+
}
2093+
2094+
// Do a second pass for all inout parameters, however this is only necessary
2095+
// for functions with multiple basic blocks. For functions with a single
2096+
// basic block adjoint accumulation for those parameters is already done by
2097+
// per-instruction visitors.
2098+
if (getOriginal().size() > 1) {
2099+
const auto &pullbackConv = pullback.getConventions();
2100+
SmallVector<SILArgument *, 1> pullbackInOutArgs;
2101+
for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) {
2102+
if (pullbackConv.getParameters()[pullbackArg.index()].isIndirectMutating())
2103+
pullbackInOutArgs.push_back(pullbackArg.value());
2104+
}
2105+
2106+
unsigned pullbackInoutArgumentIdx = 0;
2107+
for (auto i : getConfig().parameterIndices->getIndices()) {
2108+
// Skip non-inout parameters.
2109+
if (!conv.getParameters()[i].isIndirectMutating())
2110+
continue;
2111+
2112+
// For functions with multiple basic blocks, accumulation is needed
2113+
// for `inout` parameters because pullback basic blocks have different
2114+
// adjoint buffers.
2115+
pullbackIndirectResults.push_back(pullbackInOutArgs[pullbackInoutArgumentIdx++]);
20922116
addRetElt(i);
20932117
}
20942118
}
20952119

2096-
// Do a second pass for all inout parameters.
2097-
for (auto i : getConfig().parameterIndices->getIndices()) {
2098-
// Skip non-inout parameters.
2099-
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
2100-
if (!isParameterInout)
2101-
continue;
2102-
2103-
// Skip `inout` parameters for functions with a single basic block:
2104-
// adjoint accumulation for those parameters is already done by
2105-
// per-instruction visitors.
2106-
if (getOriginal().size() == 1)
2107-
continue;
2108-
2109-
// For functions with multiple basic blocks, accumulation is needed
2110-
// for `inout` parameters because pullback basic blocks have different
2111-
// adjoint buffers.
2112-
auto pullbackInoutArgument =
2113-
getPullback()
2114-
.getArgumentsWithoutIndirectResults()[pullbackInoutArgumentIndex++];
2115-
pullbackIndirectResults.push_back(pullbackInoutArgument);
2116-
addRetElt(i);
2117-
}
2118-
21192120
// Copy them to adjoint indirect results.
21202121
assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
21212122
"Indirect parameter adjoint count mismatch");
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
// SR-15891: The parameter indices used for copying
4+
// inout tangent vectors were calculated improperly
5+
// in the presence of other pullback parameters (e.g.
6+
// captures)
7+
8+
import _Differentiation
9+
10+
struct Foo {
11+
var bar : Float
12+
var baz : Float
13+
var name : String?
14+
}
15+
16+
func outerFunc(doIterations : Int, value: inout Float) -> (Float, (Float) -> Float) {
17+
@differentiable(reverse, wrt: param)
18+
func innerFunc1(param: Float, other: Foo) -> Float {
19+
value += param * other.bar
20+
return value * param * 2.0
21+
}
22+
23+
@differentiable(reverse, wrt: param1)
24+
func loop(param1 : Float, other1: Foo) -> Float {
25+
var res : Float;
26+
res = 0.0
27+
if (doIterations > 0) {
28+
res = innerFunc1(param: param1, other: other1)
29+
}
30+
31+
return res
32+
}
33+
34+
@differentiable(reverse)
35+
func curriedFunc(param: Float) -> Float {
36+
let other = Foo(bar: 7, baz: 9)
37+
return loop(param1: param, other1: other)
38+
}
39+
40+
let valAndPullback = valueWithPullback(at: value, of: curriedFunc)
41+
return (value + valAndPullback.value, valAndPullback.pullback)
42+
}

0 commit comments

Comments
 (0)