Skip to content

Use VJP in differentiation #21063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 265 additions & 2 deletions lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ using llvm::SmallDenseMap;
using llvm::SmallDenseSet;
using llvm::SmallSet;

static llvm::cl::opt<bool> DifferentiationUseVJP(
"differentiation-use-vjp", llvm::cl::init(false),
llvm::cl::desc("Use the VJP during differentiation"));

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -475,6 +479,10 @@ class PrimalInfo {
/// corresponding tape of its type.
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap;

/// Mapping from `apply` instructions in the original function to the
/// corresponding pullback decl in the primal struct.
DenseMap<ApplyInst *, VarDecl *> pullbackValueMap;

/// Mapping from types of control-dependent nested primal values to district
/// tapes.
DenseMap<CanType, VarDecl *> nestedTapeTypeMap;
Expand Down Expand Up @@ -571,6 +579,24 @@ class PrimalInfo {
return decl;
}

/// Add a pullback to the primal value struct.
VarDecl *addPullbackDecl(ApplyInst *inst, Type pullbackType) {
// Decls must have AST types (not `SILFunctionType`), so we convert the
// `SILFunctionType` of the pullback to a `FunctionType` with the same
// parameters and results.
auto *silFnTy = pullbackType->castTo<SILFunctionType>();
SmallVector<AnyFunctionType::Param, 8> params;
for (auto &param : silFnTy->getParameters())
params.push_back(AnyFunctionType::Param(param.getType()));
Type astFnTy = FunctionType::get(
params, silFnTy->getAllResultsType().getASTType());

auto *decl = addVarDecl("pullback_" + llvm::itostr(pullbackValueMap.size()),
astFnTy);
pullbackValueMap.insert({inst, decl});
return decl;
}

/// Finds the primal value decl in the primal value struct for a static primal
/// value in the original function.
VarDecl *lookupDirectStaticPrimalValueDecl(SILValue originalValue) const {
Expand All @@ -586,6 +612,14 @@ class PrimalInfo {
: lookup->getSecond();
}

/// Finds the pullback decl in the primal value struct for an `apply` in the
/// original function.
VarDecl *lookUpPullbackDecl(ApplyInst *inst) {
auto lookup = pullbackValueMap.find(inst);
return lookup == pullbackValueMap.end() ? nullptr
: lookup->getSecond();
}

/// Retrieves the tape decl in the primal value struct for the specified type.
VarDecl *getOrCreateTapeDeclForType(CanType type) {
auto &astCtx = primalValueStruct->getASTContext();
Expand Down Expand Up @@ -2390,11 +2424,139 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
SILClonerWithScopes::visitReleaseValueInst(rvi);
}

void visitApplyInst(ApplyInst *ai) {
if (DifferentiationUseVJP)
visitApplyInstWithVJP(ai);
else
visitApplyInstWithoutVJP(ai);
}

void visitApplyInstWithVJP(ApplyInst *ai) {
auto &context = getContext();
SILBuilder &builder = getBuilder();

// Special handling logic only applies when `apply` is active. If not, just
// do standard cloning.
if (!activityInfo.isActive(ai, synthesis.indices)) {
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *ai << '\n');
SILClonerWithScopes::visitApplyInst(ai);
return;
}

// This instruction is active. Replace it with a call to the VJP.

// Get the indices required for differentiating this function.
LLVM_DEBUG(getADDebugStream() << "Primal-transforming:\n" << *ai << '\n');
SmallVector<unsigned, 8> activeParamIndices;
SmallVector<unsigned, 8> activeResultIndices;
collectMinimalIndicesForFunctionCall(ai, synthesis.indices, activityInfo,
activeParamIndices,
activeResultIndices);
assert(!activeParamIndices.empty() && "Parameter indices cannot be empty");
assert(!activeResultIndices.empty() && "Result indices cannot be empty");
LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={";
interleave(activeParamIndices.begin(), activeParamIndices.end(),
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
s << "}, results={"; interleave(
activeResultIndices.begin(), activeResultIndices.end(),
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
s << "}\n";);

// FIXME: If there are mutiple active results, we don't support it yet.
if (activeResultIndices.size() > 1) {
context.emitNondifferentiabilityError(ai, synthesis.task);
errorOccurred = true;
return;
}

// Form expected indices by assuming there's only one result.
SILAutoDiffIndices indices(activeResultIndices.front(), activeParamIndices);

// Retrieve the original function being called.
auto calleeOrigin = ai->getCalleeOrigin();
auto *calleeOriginFnRef = dyn_cast<FunctionRefInst>(calleeOrigin);
// If callee does not trace back to a `function_ref`, it is an opaque
// function. Emit a "not differentiable" diagnostic here.
// FIXME: Handle `partial_apply`, `witness_method`.
if (!calleeOriginFnRef) {
context.emitNondifferentiabilityError(ai, synthesis.task);
errorOccurred = true;
return;
}

// Find or register a differentiation task for this function.
auto *newTask = context.lookUpOrRegisterDifferentiationTask(
calleeOriginFnRef->getReferencedFunction(), indices,
/*invoker*/ {ai, synthesis.task});

// Store this task so that AdjointGen can use it.
getDifferentiationTask()->getAssociatedTasks().insert({ai, newTask});

// If the task is newly created, then we need to schedule a synthesis item
// for the primal.
primalGen.lookUpPrimalAndMaybeScheduleSynthesis(newTask);

auto *vjpFn = newTask->getVJP();
assert(vjpFn);
auto *vjp = builder.createFunctionRef(ai->getCallee().getLoc(), vjpFn);

// TODO: The `visitApplyInstWithoutVJP` reapplies function conversions here,
// but all the tests seem to pass without doing that here. Investigate.

// Call the VJP using the original parameters.
SmallVector<SILValue, 8> newArgs;
auto vjpFnTy = vjpFn->getLoweredFunctionType();
auto numVJPParams = vjpFnTy->getNumParameters();
assert(vjpFnTy->getNumIndirectFormalResults() == 0 &&
"FIXME: handle vjp with indirect results");
newArgs.reserve(numVJPParams);
// Collect substituted arguments.
for (auto origArg : ai->getArguments())
newArgs.push_back(getOpValue(origArg));
assert(newArgs.size() == numVJPParams);
// Apply the VJP.
auto *vjpCall = builder.createApply(ai->getLoc(), vjp,
ai->getSubstitutionMap(), newArgs,
ai->isNonThrowing());
LLVM_DEBUG(getADDebugStream()
<< "Applied vjp function\n" << *vjpCall);

// Get the VJP results (original results and pullback).
SmallVector<SILValue, 8> vjpDirectResults;
extractAllElements(vjpCall, builder, vjpDirectResults);
ArrayRef<SILValue> originalDirectResults =
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
SILValue originalDirectResult = joinElements(originalDirectResults,
builder,
vjpCall->getLoc());
SILValue pullback = vjpDirectResults.back();

// Store the original result to the value map.
ValueMap.insert({ai, originalDirectResult});

// Checkpoint the original results.
getPrimalInfo().addStaticPrimalValueDecl(ai);
staticPrimalValues.push_back(originalDirectResult);

// Checkpoint the pullback.
getPrimalInfo().addPullbackDecl(ai, pullback->getType().getASTType());
staticPrimalValues.push_back(pullback);

// Some instructions that produce the callee may have been cloned.
// If the original callee did not have any users beyond this `apply`,
// recursively kill the cloned callee.
if (auto *origCallee = cast_or_null<SingleValueInstruction>(
ai->getCallee()->getDefiningInstruction()))
if (origCallee->hasOneUse())
recursivelyDeleteTriviallyDeadInstructions(
getOpValue(origCallee)->getDefiningInstruction());
}

/// Handle the primal transformation of an `apply` instruction. We do not
/// always transform `apply`. When we do, we do not just blindly differentiate
/// from all results w.r.t. all parameters. Instead, we let activity analysis
/// decide whether to transform and what differentiation indices to use.
void visitApplyInst(ApplyInst *ai) {
void visitApplyInstWithoutVJP(ApplyInst *ai) {
// Special handling logic only applies when `apply` is active. If not, just
// do standard cloning.
if (!activityInfo.isActive(ai, synthesis.indices)) {
Expand Down Expand Up @@ -3292,9 +3454,110 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
return rematCloner.getMappedValue(value);
}

void visitApplyInst(ApplyInst *ai) {
if (DifferentiationUseVJP)
visitApplyInstWithVJP(ai);
else
visitApplyInstWithoutVJP(ai);
}

void visitApplyInstWithVJP(ApplyInst *ai) {
// Replace a call to a function with a call to its pullback.

auto &builder = getBuilder();
auto loc = remapLocation(ai->getLoc());

// Look for the task that differentiates the callee.
auto &assocTasks = getDifferentiationTask()->getAssociatedTasks();
auto assocTaskLookUp = assocTasks.find(ai);
// If no task was found, then this task doesn't need to be differentiated.
if (assocTaskLookUp == assocTasks.end()) {
// Must not be active.
assert(
!activityInfo.isActive(ai, getDifferentiationTask()->getIndices()));
return;
}
auto *otherTask = assocTaskLookUp->getSecond();
auto origTy = otherTask->getOriginal()->getLoweredFunctionType();
SILFunctionConventions origConvs(origTy, getModule());

// Get the pullback.
auto *field = getPrimalInfo().lookUpPullbackDecl(ai);
assert(field);
SILValue pullback = builder.createStructExtract(remapLocation(ai->getLoc()),
primalValueAggregateInAdj,
field);

// Construct the pullback arguments.
SmallVector<SILValue, 8> args;
auto seed = getAdjointValue(ai);
auto *seedBuf = getBuilder().createAllocStack(loc, seed.getType());
materializeAdjointIndirect(seed, seedBuf);
if (seed.getType().isAddressOnly(getModule()))
args.push_back(seedBuf);
else {
auto access = getBuilder().createBeginAccess(
loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
/*noNestedConflict*/ true,
/*fromBuiltin*/ false);
args.push_back(getBuilder().createLoad(
loc, access, getBufferLOQ(seed.getSwiftType(), getAdjoint())));
getBuilder().createEndAccess(loc, access, /*aborted*/ false);
}

// Call the pullback.
auto *pullbackCall = builder.createApply(ai->getLoc(), pullback,
SubstitutionMap(), args,
/*isNonThrowing*/ false);

// Clean up seed allocation.
getBuilder().createDeallocStack(loc, seedBuf);

// If `pullbackCall` is a tuple, extract all results.
SmallVector<SILValue, 8> dirResults;
extractAllElements(pullbackCall, builder, dirResults);
// Get all results in type-defined order.
SmallVector<SILValue, 8> allResults;
collectAllActualResultsInTypeOrder(
pullbackCall, dirResults, pullbackCall->getIndirectSILResults(),
allResults);
LLVM_DEBUG({
auto &s = getADDebugStream();
s << "All direct results of the nested pullback call: \n";
llvm::for_each(dirResults, [&](SILValue v) { s << v; });
s << "All indirect results of the nested pullback call: \n";
llvm::for_each(pullbackCall->getIndirectSILResults(),
[&](SILValue v) { s << v; });
s << "All results of the nested pullback call: \n";
llvm::for_each(allResults, [&](SILValue v) { s << v; });
});

// Set adjoints for all original parameters.
auto originalParams = ai->getArgumentsWithoutIndirectResults();
auto origNumIndRes = origConvs.getNumIndirectSILResults();
auto allResultsIt = allResults.begin();
// If the applied adjoint returns the adjoint of the original self
// parameter, then it returns it first. Set the adjoint of the original
// self parameter.
auto selfParamIndex = originalParams.size() - 1;
if (ai->hasSelfArgument() &&
otherTask->getIndices().isWrtParameter(selfParamIndex))
addAdjointValue(ai->getArgument(origNumIndRes + selfParamIndex),
AdjointValue::getMaterialized(*allResultsIt++));
// Set adjoints for the remaining non-self original parameters.
for (unsigned i : otherTask->getIndices().parameters.set_bits()) {
// Do not set the adjoint of the original self parameter because we
// already added it at the beginning.
if (ai->hasSelfArgument() && i == selfParamIndex)
continue;
addAdjointValue(ai->getArgument(origNumIndRes + i),
AdjointValue::getMaterialized(*allResultsIt++));
}
}

/// Handle `apply` instruction. If it's active (on the differentiation path),
/// we replace it with its corresponding adjoint.
void visitApplyInst(ApplyInst *ai) {
void visitApplyInstWithoutVJP(ApplyInst *ai) {
// Replace a call to the function with a call to its adjoint.
auto &assocTasks = getDifferentiationTask()->getAssociatedTasks();
auto assocTaskLookUp = assocTasks.find(ai);
Expand Down
1 change: 1 addition & 0 deletions test/AutoDiff/autodiff_e2e_basic.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
// RUN: %target-swift-frontend -Xllvm -differentiation-use-vjp -emit-sil %s | %FileCheck %s

@differentiable(reverse, adjoint: adjointId)
func id(_ x: Float) -> Float {
Expand Down
1 change: 1 addition & 0 deletions test/AutoDiff/method.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-use-vjp-swift
// REQUIRES: executable_test

import StdlibUnittest
Expand Down
Loading