Skip to content

[AutoDiff] [SR-14218] Correctly propagate tangent vectors of inout parameters from functions with multiple basic blocks #37861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,8 @@ bool PullbackCloner::Implementation::run() {
SmallVector<SILValue, 8> retElts;
// This vector will contain all indirect parameter adjoint buffers.
SmallVector<SILValue, 4> indParamAdjoints;
// This vector will identify the locations where initialization is needed.
SmallBitVector outputsToInitialize;

auto conv = getOriginal().getConventions();
auto origParams = getOriginal().getArgumentsWithoutIndirectResults();
Expand All @@ -2071,25 +2073,62 @@ bool PullbackCloner::Implementation::run() {
case SILValueCategory::Address: {
auto adjBuf = getAdjointBuffer(origEntry, origParam);
indParamAdjoints.push_back(adjBuf);
outputsToInitialize.push_back(
!conv.getParameters()[parameterIndex].isIndirectMutating());
break;
}
}
};
SmallVector<SILArgument *, 4> pullbackIndirectResults(
getPullback().getIndirectResults().begin(),
getPullback().getIndirectResults().end());

// Collect differentiation parameter adjoints.
// Do a first pass to collect non-inout values.
unsigned pullbackInoutArgumentIndex = 0;
for (auto i : getConfig().parameterIndices->getIndices()) {
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
if (!isParameterInout) {
addRetElt(i);
}
}

// Do a second pass for all inout parameters.
for (auto i : getConfig().parameterIndices->getIndices()) {
// Skip `inout` parameters.
if (conv.getParameters()[i].isIndirectMutating())
// Skip non-inout parameters.
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
if (!isParameterInout)
continue;

// Skip `inout` parameters for functions with a single basic block:
// adjoint accumulation for those parameters is already done by
// per-instruction visitors.
if (getOriginal().size() == 1)
continue;

// For functions with multiple basic blocks, accumulation is needed
// for `inout` parameters because pullback basic blocks have different
// adjoint buffers.
auto pullbackInoutArgument =
getPullback()
.getArgumentsWithoutIndirectResults()[pullbackInoutArgumentIndex++];
pullbackIndirectResults.push_back(pullbackInoutArgument);
addRetElt(i);
}

// Copy them to adjoint indirect results.
assert(indParamAdjoints.size() == getPullback().getIndirectResults().size() &&
assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
"Indirect parameter adjoint count mismatch");
for (auto pair : zip(indParamAdjoints, getPullback().getIndirectResults())) {
unsigned currentIndex = 0;
for (auto pair : zip(indParamAdjoints, pullbackIndirectResults)) {
auto source = std::get<0>(pair);
auto *dest = std::get<1>(pair);
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
if (outputsToInitialize[currentIndex]) {
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
} else {
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization);
}
currentIndex++;
// Prevent source buffer from being deallocated, since the underlying
// value is moved.
destroyedLocalAllocations.insert(source);
Expand Down
88 changes: 88 additions & 0 deletions test/AutoDiff/validation-test/inout_control_flow.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: %target-run-simple-swift
// REQUIRES: executable_test

import StdlibUnittest
import _Differentiation

// Test suite named "InoutControlFlow"; the `.test(...)` blocks below register
// their cases on it.
var InoutControlFlowTests = TestSuite("InoutControlFlow")

// SR-14218
// Differentiable fixture with two `Float` stored properties. Its mutating
// methods are differentiated through `self` by the `loss` function below.
struct Model: Differentiable {
// Defaults used when a test constructs `Model()`.
var first: Float = 3
var second: Float = 1

// Forwards to `inner()`, so the mutation of `self` happens one call deeper.
mutating func outer() {
inner()
}

// Overwrites `second` with `first`, then runs a branch that does nothing.
mutating func inner() {
self.second = self.first

// Dummy no-op if block, required to introduce control flow.
// NOTE(review): presumably this makes the function span more than one basic
// block, which is the case the fix under test targets — confirm.
let x = 5
if x < 50 {}
}
}

// Copies `model`, calls the mutating `outer()` on the copy, and returns the
// copy's `second`. Because `inner()` assigns `second = first`, the result
// depends only on the input's `first`.
@differentiable(reverse)
func loss(model: Model) -> Float {
  // A local mutable copy is needed because `outer()` is mutating.
  var model = model
  model.outer()
  return model.second
}

InoutControlFlowTests.test("MutatingBeforeControlFlow") {
  // The model is only read here, so bind it immutably.
  let model = Model()
  let grad = gradient(at: model, of: loss)
  // `loss` returns `second` after `inner()` has set `second = first`, so the
  // result is the input's `first`: derivative 1 w.r.t. `first`, 0 w.r.t.
  // `second`.
  expectEqual(1, grad.first)
  expectEqual(0, grad.second)
}

// SR-14053
// Writes `model.second * multiplier` into `model.first` through the `inout`
// parameter, then runs a branch that does nothing. Here the `inout` parameter
// comes first in the parameter list.
@differentiable(reverse)
func adjust(model: inout Model, multiplier: Float) {
model.first = model.second * multiplier

// Dummy no-op if block, required to introduce control flow.
// NOTE(review): presumably this makes the function span more than one basic
// block — confirm.
let x = 5
if x < 50 {}
}

// Copies `model`, passes the copy to `adjust` as `inout`, and returns the
// copy's `first`, i.e. the input's `second` times `multiplier`.
@differentiable(reverse)
func loss2(model: Model, multiplier: Float) -> Float {
// A local mutable copy is needed to pass the model as `inout`.
var model = model
adjust(model: &model, multiplier: multiplier)
return model.first
}

InoutControlFlowTests.test("InoutParameterWithControlFlow") {
  // The model is only read here, so bind it immutably.
  let model = Model(first: 1, second: 3)
  let grad = gradient(at: model, 5.0, of: loss2)
  // `loss2` returns `second * multiplier`, so the input's `first` does not
  // contribute (0) and the derivative w.r.t. `second` is the multiplier (5).
  expectEqual(0, grad.0.first)
  expectEqual(5, grad.0.second)
}

// Same body as `adjust`, but the `inout` parameter comes after the
// non-`inout` `multiplier` parameter.
@differentiable(reverse)
func adjust2(multiplier: Float, model: inout Model) {
model.first = model.second * multiplier

// Dummy no-op if block, required to introduce control flow.
// NOTE(review): presumably this makes the function span more than one basic
// block — confirm.
let x = 5
if x < 50 {}
}

// Copies `model`, passes the copy to `adjust2` as `inout`, and returns the
// copy's `first`, i.e. the input's `second` times `multiplier`.
@differentiable(reverse)
func loss3(model: Model, multiplier: Float) -> Float {
// A local mutable copy is needed to pass the model as `inout`.
var model = model
adjust2(multiplier: multiplier, model: &model)
return model.first
}

InoutControlFlowTests.test("LaterInoutParameterWithControlFlow") {
  // The model is only read here, so bind it immutably.
  let model = Model(first: 1, second: 3)
  let grad = gradient(at: model, 5.0, of: loss3)
  // `loss3` returns `second * multiplier`, so the input's `first` does not
  // contribute (0) and the derivative w.r.t. `second` is the multiplier (5).
  expectEqual(0, grad.0.first)
  expectEqual(5, grad.0.second)
}

// Entry point of this executable test: run the cases registered above.
// NOTE(review): `runAllTests` is presumably provided by StdlibUnittest — confirm.
runAllTests()