Skip to content

[AutoDiff] Properly collect inout parameter adjoints #41559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2085,37 +2085,38 @@ bool PullbackCloner::Implementation::run() {

// Collect differentiation parameter adjoints.
// Do a first pass to collect non-inout values.
unsigned pullbackInoutArgumentIndex = 0;
for (auto i : getConfig().parameterIndices->getIndices()) {
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
if (!isParameterInout) {
if (!conv.getParameters()[i].isIndirectMutating()) {
addRetElt(i);
}
}

// Do a second pass for all inout parameters. This is only necessary for
// functions with multiple basic blocks: for functions with a single basic
// block, adjoint accumulation for those parameters is already done by
// per-instruction visitors.
if (getOriginal().size() > 1) {
const auto &pullbackConv = pullback.getConventions();
SmallVector<SILArgument *, 1> pullbackInOutArgs;
for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) {
if (pullbackConv.getParameters()[pullbackArg.index()].isIndirectMutating())
pullbackInOutArgs.push_back(pullbackArg.value());
}

unsigned pullbackInoutArgumentIdx = 0;
for (auto i : getConfig().parameterIndices->getIndices()) {
// Skip non-inout parameters.
if (!conv.getParameters()[i].isIndirectMutating())
continue;

// For functions with multiple basic blocks, accumulation is needed
// for `inout` parameters because pullback basic blocks have different
// adjoint buffers.
pullbackIndirectResults.push_back(pullbackInOutArgs[pullbackInoutArgumentIdx++]);
addRetElt(i);
}
}

// Do a second pass for all inout parameters.
for (auto i : getConfig().parameterIndices->getIndices()) {
// Skip non-inout parameters.
auto isParameterInout = conv.getParameters()[i].isIndirectMutating();
if (!isParameterInout)
continue;

// Skip `inout` parameters for functions with a single basic block:
// adjoint accumulation for those parameters is already done by
// per-instruction visitors.
if (getOriginal().size() == 1)
continue;

// For functions with multiple basic blocks, accumulation is needed
// for `inout` parameters because pullback basic blocks have different
// adjoint buffers.
auto pullbackInoutArgument =
getPullback()
.getArgumentsWithoutIndirectResults()[pullbackInoutArgumentIndex++];
pullbackIndirectResults.push_back(pullbackInoutArgument);
addRetElt(i);
}

// Copy them to adjoint indirect results.
assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
"Indirect parameter adjoint count mismatch");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: %target-swift-frontend -emit-sil -verify %s

// SR-15891: The parameter indices used for copying
// inout tangent vectors were calculated incorrectly
// in the presence of other pullback parameters (e.g.
// captures).

import _Differentiation

/// Auxiliary value passed next to the differentiated parameter in this test.
/// It is never listed in a `wrt:` clause, so it is not differentiated against.
struct Foo {
  // The two numeric fields share one declaration; memberwise-init order
  // stays `bar`, `baz`, `name`.
  var bar, baz: Float
  // Optional field; the call site in this file omits it.
  var name: String?
}

/// Reproducer for SR-15891.
///
/// Declares three nested differentiable functions that capture `doIterations`
/// and the `inout` argument `value`, then differentiates the outermost one
/// with `valueWithPullback`.
///
/// - Parameters:
///   - doIterations: `innerFunc1` is only invoked when this is greater than 0.
///   - value: Captured and mutated by `innerFunc1`; also used as the point at
///     which `curriedFunc` is differentiated.
/// - Returns: A pair of `value` plus the result of `curriedFunc`, and the
///   pullback closure produced by `valueWithPullback`.
func outerFunc(doIterations : Int, value: inout Float) -> (Float, (Float) -> Float) {
// Differentiable with respect to `param` only; `other` is a plain argument.
@differentiable(reverse, wrt: param)
func innerFunc1(param: Float, other: Foo) -> Float {
// Mutates the captured `inout` argument of the enclosing function.
value += param * other.bar
return value * param * 2.0
}

// Differentiable with respect to `param1` only.
@differentiable(reverse, wrt: param1)
func loop(param1 : Float, other1: Foo) -> Float {
var res : Float;
res = 0.0
// Conditional call: `res` stays 0.0 unless `doIterations` is positive.
// NOTE(review): this branch is presumably what gives the function several
// basic blocks, the case the fixed pullback code handles — confirm.
if (doIterations > 0) {
res = innerFunc1(param: param1, other: other1)
}

return res
}

// Single-parameter wrapper so the whole chain can be handed to
// `valueWithPullback`.
@differentiable(reverse)
func curriedFunc(param: Float) -> Float {
// `name` is not supplied here.
let other = Foo(bar: 7, baz: 9)
return loop(param1: param, other1: other)
}

let valAndPullback = valueWithPullback(at: value, of: curriedFunc)
return (value + valAndPullback.value, valAndPullback.pullback)
}