Skip to content

Commit 3277457

Browse files
committed
Incorporated Dan's cleanup suggestions.
1 parent e4e544d commit 3277457

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,25 +2093,27 @@ bool PullbackCloner::Implementation::run() {
20932093
}
20942094
}
20952095

2096-
// Do a second pass for all inout values.
2096+
// Do a second pass for all inout parameters.
20972097
for (auto i : getConfig().parameterIndices->getIndices()) {
2098-
// Skip `inout` parameters for functions with single basic blocks:
2099-
// additional adjoint accumulation is not necessary.
2098+
// Skip non-inout parameters.
2099+
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
2100+
if (!isParameterInout)
2101+
continue;
2102+
2103+
// Skip `inout` parameters for functions with a single basic block:
2104+
// adjoint accumulation for those parameters is already done by
2105+
// per-instruction visitors.
2106+
if (getOriginal().size() == 1)
2107+
continue;
21002108

21012109
// For functions with multiple basic blocks, accumulation is needed
21022110
// for `inout` parameters because pullback basic blocks have different
21032111
// adjoint buffers.
2104-
auto isOriginalFunctionSingleBasicBlock = getOriginal().size() == 1;
2105-
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
2106-
if (isParameterInout) {
2107-
if (isOriginalFunctionSingleBasicBlock) {
2108-
continue;
2109-
} else {
2110-
pullbackIndirectResults.push_back(
2111-
getPullback().getArgumentsWithoutIndirectResults()[pullbackInoutArgumentIndex++]);
2112-
}
2113-
addRetElt(i);
2114-
}
2112+
auto pullbackInoutArgument =
2113+
getPullback()
2114+
.getArgumentsWithoutIndirectResults()[pullbackInoutArgumentIndex++];
2115+
pullbackIndirectResults.push_back(pullbackInoutArgument);
2116+
addRetElt(i);
21152117
}
21162118

21172119
// Copy them to adjoint indirect results.
@@ -2126,9 +2128,9 @@ bool PullbackCloner::Implementation::run() {
21262128
} else {
21272129
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization);
21282130
}
2131+
currentIndex++;
21292132
// Prevent source buffer from being deallocated, since the underlying
21302133
// value is moved.
2131-
currentIndex++;
21322134
destroyedLocalAllocations.insert(source);
21332135
}
21342136

test/AutoDiff/validation-test/inout_control_flow.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ struct Model: Differentiable {
1111
var first: Float = 3
1212
var second: Float = 1
1313

14-
mutating func outer(){
14+
mutating func outer() {
1515
inner()
1616
}
1717

1818
mutating func inner() {
1919
self.second = self.first
2020

21+
// Dummy no-op if block, required to introduce control flow.
2122
let x = 5
2223
if x < 50 {}
2324
}
@@ -53,6 +54,7 @@ struct Model2<T: NumericDifferentiable>: Differentiable {
5354
func adjust<T: NumericDifferentiable>(model: inout Model2<T>, multiplier: T) {
5455
model.first = model.second * multiplier
5556

57+
// Dummy no-op if block, required to introduce control flow.
5658
let x = 5
5759
if x < 50 {}
5860
}
@@ -75,6 +77,7 @@ InoutControlFlowTests.test("InoutParameterWithControlFlow") {
7577
func adjust2<T: NumericDifferentiable>(multiplier: T, model: inout Model2<T>) {
7678
model.first = model.second * multiplier
7779

80+
// Dummy no-op if block, required to introduce control flow.
7881
let x = 5
7982
if x < 50 {}
8083
}

0 commit comments

Comments
 (0)