Skip to content

Commit e4e544d

Browse files
committed
Attempt at a patch for SR-14053 and SR-14218, based on Dan Zheng's initial fix.
1 parent 8e656ee commit e4e544d

File tree

2 files changed

+140
-7
lines changed

2 files changed

+140
-7
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,8 @@ bool PullbackCloner::Implementation::run() {
20522052
SmallVector<SILValue, 8> retElts;
20532053
// This vector will contain all indirect parameter adjoint buffers.
20542054
SmallVector<SILValue, 4> indParamAdjoints;
2055+
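// Parallel to `indParamAdjoints`: `true` when the copy into the pullback
// destination must initialize it (non-`inout` parameters), `false` when the
// destination is an `inout` parameter and the copy is a plain assignment.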
2056+
SmallVector<bool, 8> outputsToInitialize;
20552057

20562058
auto conv = getOriginal().getConventions();
20572059
auto origParams = getOriginal().getArgumentsWithoutIndirectResults();
@@ -2071,27 +2073,62 @@ bool PullbackCloner::Implementation::run() {
20712073
case SILValueCategory::Address: {
20722074
auto adjBuf = getAdjointBuffer(origEntry, origParam);
20732075
indParamAdjoints.push_back(adjBuf);
2076+
outputsToInitialize.push_back(
2077+
!conv.getParameters()[parameterIndex].isIndirectMutating());
20742078
break;
20752079
}
20762080
}
20772081
};
2082+
SmallVector<SILArgument *, 4> pullbackIndirectResults(
2083+
getPullback().getIndirectResults().begin(),
2084+
getPullback().getIndirectResults().end());
2085+
20782086
// Collect differentiation parameter adjoints.
2087+
// Do a first pass to collect non-inout values.
2088+
unsigned pullbackInoutArgumentIndex = 0;
20792089
for (auto i : getConfig().parameterIndices->getIndices()) {
2080-
// Skip `inout` parameters.
2081-
if (conv.getParameters()[i].isIndirectMutating())
2082-
continue;
2083-
addRetElt(i);
2090+
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
2091+
if (!isParameterInout) {
2092+
addRetElt(i);
2093+
}
2094+
}
2095+
2096+
// Do a second pass for all inout values.
2097+
for (auto i : getConfig().parameterIndices->getIndices()) {
2098+
// Skip `inout` parameters for functions with single basic blocks:
2099+
// additional adjoint accumulation is not necessary.
2100+
2101+
// For functions with multiple basic blocks, accumulation is needed
2102+
// for `inout` parameters because pullback basic blocks have different
2103+
// adjoint buffers.
2104+
auto isOriginalFunctionSingleBasicBlock = getOriginal().size() == 1;
2105+
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
2106+
if (isParameterInout) {
2107+
if (isOriginalFunctionSingleBasicBlock) {
2108+
continue;
2109+
} else {
2110+
pullbackIndirectResults.push_back(
2111+
getPullback().getArgumentsWithoutIndirectResults()[pullbackInoutArgumentIndex++]);
2112+
}
2113+
addRetElt(i);
2114+
}
20842115
}
20852116

20862117
// Copy them to adjoint indirect results.
2087-
assert(indParamAdjoints.size() == getPullback().getIndirectResults().size() &&
2118+
assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
20882119
"Indirect parameter adjoint count mismatch");
2089-
for (auto pair : zip(indParamAdjoints, getPullback().getIndirectResults())) {
2120+
unsigned currentIndex = 0;
2121+
for (auto pair : zip(indParamAdjoints, pullbackIndirectResults)) {
20902122
auto source = std::get<0>(pair);
20912123
auto *dest = std::get<1>(pair);
2092-
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
2124+
if (outputsToInitialize[currentIndex]) {
2125+
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
2126+
} else {
2127+
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization);
2128+
}
20932129
// Prevent source buffer from being deallocated, since the underlying
20942130
// value is moved.
2131+
currentIndex++;
20952132
destroyedLocalAllocations.insert(source);
20962133
}
20972134

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
import _Differentiation
6+
7+
var InoutControlFlowTests = TestSuite("InoutControlFlow")
8+
9+
// SR-14218
10+
/// Two-field differentiable model for the SR-14218 test: one `mutating`
/// method calls another `mutating` method that contains a conditional.
struct Model: Differentiable {
  var first: Float = 3
  var second: Float = 1

  /// Forwards to `inner()`.
  mutating func outer() {
    inner()
  }

  /// Overwrites `second` with `first`, then evaluates an empty conditional.
  mutating func inner() {
    second = first

    // The empty `if` is intentional. NOTE(review): presumably it exists so
    // the function body is split across several basic blocks; confirm
    // against the compiler change this test accompanies.
    let limit = 5
    if limit < 50 {}
  }
}
25+
26+
/// Returns the `second` field of a copy of `model` after `outer()` has
/// mutated that copy.
@differentiable(reverse)
func loss(model: Model) -> Float {
  var mutated = model
  mutated.outer()
  return mutated.second
}
32+
33+
InoutControlFlowTests.test("MutatingBeforeControlFlow") {
  // The model is only read here, so it is bound with `let`; `loss` makes its
  // own mutable copy internally.
  let model = Model()
  let grad = gradient(at: model, of: loss)
  // `loss` returns `second` after it has been overwritten with `first`, so
  // the result depends on `first` with slope 1 and not on the original
  // `second`.
  expectEqual(1, grad.first)
  expectEqual(0, grad.second)
}
39+
40+
// SR-14053
41+
/// A `Numeric`, `Differentiable` type whose `*` operator is declared
/// reverse-mode differentiable.
protocol NumericDifferentiable: Numeric, Differentiable {
  @differentiable(reverse)
  static func * (lhs: Self, rhs: Self) -> Self
}

// `Float` is the concrete type the tests below instantiate `Model2` with.
extension Float: NumericDifferentiable {}
46+
47+
/// Generic two-field differentiable model used by the SR-14053 tests.
struct Model2<T: NumericDifferentiable>: Differentiable {
  /// Field overwritten by `adjust` / `adjust2`.
  var first: T
  /// Field read by `adjust` / `adjust2` to compute the new `first`.
  var second: T
}
51+
52+
/// Overwrites `model.first` with `model.second * multiplier`, then evaluates
/// an empty conditional. The `inout` parameter comes first in this variant.
@differentiable(reverse)
func adjust<T: NumericDifferentiable>(model: inout Model2<T>, multiplier: T) {
  model.first = model.second * multiplier

  // The empty `if` is intentional. NOTE(review): presumably it exists so
  // the function body is split across several basic blocks; confirm.
  let limit = 5
  if limit < 50 {}
}
59+
60+
/// Returns the `first` field of a copy of `model` after `adjust` has been
/// applied to that copy with `multiplier`.
@differentiable(reverse)
func loss2(model: Model2<Float>, multiplier: Float) -> Float {
  var adjusted = model
  adjust(model: &adjusted, multiplier: multiplier)
  return adjusted.first
}
66+
67+
InoutControlFlowTests.test("InoutParameterWithControlFlow") {
  // The model is only read here, so it is bound with `let`; `loss2` makes
  // its own mutable copy internally.
  let model = Model2<Float>(first: 1, second: 3)
  let grad = gradient(at: model, 5.0, of: loss2)
  // `loss2` returns `second * multiplier`: the original `first` does not
  // contribute, and the slope with respect to `second` is the multiplier, 5.
  expectEqual(0, grad.0.first)
  expectEqual(5, grad.0.second)
}
73+
74+
/// Same computation as `adjust`, but with the `inout` parameter placed after
/// the by-value parameter: overwrites `model.first` with
/// `model.second * multiplier`, then evaluates an empty conditional.
@differentiable(reverse)
func adjust2<T: NumericDifferentiable>(multiplier: T, model: inout Model2<T>) {
  model.first = model.second * multiplier

  // The empty `if` is intentional. NOTE(review): presumably it exists so
  // the function body is split across several basic blocks; confirm.
  let limit = 5
  if limit < 50 {}
}
81+
82+
/// Returns the `first` field of a copy of `model` after `adjust2` has been
/// applied to that copy with `multiplier`.
@differentiable(reverse)
func loss3(model: Model2<Float>, multiplier: Float) -> Float {
  var adjusted = model
  adjust2(multiplier: multiplier, model: &adjusted)
  return adjusted.first
}
88+
89+
InoutControlFlowTests.test("LaterInoutParameterWithControlFlow") {
  // The model is only read here, so it is bound with `let`; `loss3` makes
  // its own mutable copy internally.
  let model = Model2<Float>(first: 1, second: 3)
  let grad = gradient(at: model, 5.0, of: loss3)
  // `loss3` returns `second * multiplier`: the original `first` does not
  // contribute, and the slope with respect to `second` is the multiplier, 5.
  expectEqual(0, grad.0.first)
  expectEqual(5, grad.0.second)
}
95+
96+
runAllTests()

0 commit comments

Comments
 (0)