Skip to content

Commit c91a626

Browse files
authored
Use VJP in differentiation (#21063)
1 parent 033f09a commit c91a626

File tree

8 files changed

+345
-36
lines changed

8 files changed

+345
-36
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 265 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ using llvm::SmallDenseMap;
5454
using llvm::SmallDenseSet;
5555
using llvm::SmallSet;
5656

57+
static llvm::cl::opt<bool> DifferentiationUseVJP(
58+
"differentiation-use-vjp", llvm::cl::init(false),
59+
llvm::cl::desc("Use the VJP during differentiation"));
60+
5761
//===----------------------------------------------------------------------===//
5862
// Helpers
5963
//===----------------------------------------------------------------------===//
@@ -475,6 +479,10 @@ class PrimalInfo {
475479
/// corresponding tape of its type.
476480
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap;
477481

482+
/// Mapping from `apply` instructions in the original function to the
483+
/// corresponding pullback decl in the primal struct.
484+
DenseMap<ApplyInst *, VarDecl *> pullbackValueMap;
485+
478486
/// Mapping from types of control-dependent nested primal values to district
479487
/// tapes.
480488
DenseMap<CanType, VarDecl *> nestedTapeTypeMap;
@@ -571,6 +579,24 @@ class PrimalInfo {
571579
return decl;
572580
}
573581

582+
/// Add a pullback to the primal value struct.
583+
VarDecl *addPullbackDecl(ApplyInst *inst, Type pullbackType) {
584+
// Decls must have AST types (not `SILFunctionType`), so we convert the
585+
// `SILFunctionType` of the pullback to a `FunctionType` with the same
586+
// parameters and results.
587+
auto *silFnTy = pullbackType->castTo<SILFunctionType>();
588+
SmallVector<AnyFunctionType::Param, 8> params;
589+
for (auto &param : silFnTy->getParameters())
590+
params.push_back(AnyFunctionType::Param(param.getType()));
591+
Type astFnTy = FunctionType::get(
592+
params, silFnTy->getAllResultsType().getASTType());
593+
594+
auto *decl = addVarDecl("pullback_" + llvm::itostr(pullbackValueMap.size()),
595+
astFnTy);
596+
pullbackValueMap.insert({inst, decl});
597+
return decl;
598+
}
599+
574600
/// Finds the primal value decl in the primal value struct for a static primal
575601
/// value in the original function.
576602
VarDecl *lookupDirectStaticPrimalValueDecl(SILValue originalValue) const {
@@ -586,6 +612,14 @@ class PrimalInfo {
586612
: lookup->getSecond();
587613
}
588614

615+
/// Finds the pullback decl in the primal value struct for an `apply` in the
616+
/// original function.
617+
VarDecl *lookUpPullbackDecl(ApplyInst *inst) {
618+
auto lookup = pullbackValueMap.find(inst);
619+
return lookup == pullbackValueMap.end() ? nullptr
620+
: lookup->getSecond();
621+
}
622+
589623
/// Retrieves the tape decl in the primal value struct for the specified type.
590624
VarDecl *getOrCreateTapeDeclForType(CanType type) {
591625
auto &astCtx = primalValueStruct->getASTContext();
@@ -2390,11 +2424,139 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
23902424
SILClonerWithScopes::visitReleaseValueInst(rvi);
23912425
}
23922426

2427+
void visitApplyInst(ApplyInst *ai) {
2428+
if (DifferentiationUseVJP)
2429+
visitApplyInstWithVJP(ai);
2430+
else
2431+
visitApplyInstWithoutVJP(ai);
2432+
}
2433+
2434+
void visitApplyInstWithVJP(ApplyInst *ai) {
2435+
auto &context = getContext();
2436+
SILBuilder &builder = getBuilder();
2437+
2438+
// Special handling logic only applies when `apply` is active. If not, just
2439+
// do standard cloning.
2440+
if (!activityInfo.isActive(ai, synthesis.indices)) {
2441+
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *ai << '\n');
2442+
SILClonerWithScopes::visitApplyInst(ai);
2443+
return;
2444+
}
2445+
2446+
// This instruction is active. Replace it with a call to the VJP.
2447+
2448+
// Get the indices required for differentiating this function.
2449+
LLVM_DEBUG(getADDebugStream() << "Primal-transforming:\n" << *ai << '\n');
2450+
SmallVector<unsigned, 8> activeParamIndices;
2451+
SmallVector<unsigned, 8> activeResultIndices;
2452+
collectMinimalIndicesForFunctionCall(ai, synthesis.indices, activityInfo,
2453+
activeParamIndices,
2454+
activeResultIndices);
2455+
assert(!activeParamIndices.empty() && "Parameter indices cannot be empty");
2456+
assert(!activeResultIndices.empty() && "Result indices cannot be empty");
2457+
LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={";
2458+
interleave(activeParamIndices.begin(), activeParamIndices.end(),
2459+
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
2460+
s << "}, results={"; interleave(
2461+
activeResultIndices.begin(), activeResultIndices.end(),
2462+
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
2463+
s << "}\n";);
2464+
2465+
// FIXME: If there are mutiple active results, we don't support it yet.
2466+
if (activeResultIndices.size() > 1) {
2467+
context.emitNondifferentiabilityError(ai, synthesis.task);
2468+
errorOccurred = true;
2469+
return;
2470+
}
2471+
2472+
// Form expected indices by assuming there's only one result.
2473+
SILAutoDiffIndices indices(activeResultIndices.front(), activeParamIndices);
2474+
2475+
// Retrieve the original function being called.
2476+
auto calleeOrigin = ai->getCalleeOrigin();
2477+
auto *calleeOriginFnRef = dyn_cast<FunctionRefInst>(calleeOrigin);
2478+
// If callee does not trace back to a `function_ref`, it is an opaque
2479+
// function. Emit a "not differentiable" diagnostic here.
2480+
// FIXME: Handle `partial_apply`, `witness_method`.
2481+
if (!calleeOriginFnRef) {
2482+
context.emitNondifferentiabilityError(ai, synthesis.task);
2483+
errorOccurred = true;
2484+
return;
2485+
}
2486+
2487+
// Find or register a differentiation task for this function.
2488+
auto *newTask = context.lookUpOrRegisterDifferentiationTask(
2489+
calleeOriginFnRef->getReferencedFunction(), indices,
2490+
/*invoker*/ {ai, synthesis.task});
2491+
2492+
// Store this task so that AdjointGen can use it.
2493+
getDifferentiationTask()->getAssociatedTasks().insert({ai, newTask});
2494+
2495+
// If the task is newly created, then we need to schedule a synthesis item
2496+
// for the primal.
2497+
primalGen.lookUpPrimalAndMaybeScheduleSynthesis(newTask);
2498+
2499+
auto *vjpFn = newTask->getVJP();
2500+
assert(vjpFn);
2501+
auto *vjp = builder.createFunctionRef(ai->getCallee().getLoc(), vjpFn);
2502+
2503+
// TODO: The `visitApplyInstWithoutVJP` reapplies function conversions here,
2504+
// but all the tests seem to pass without doing that here. Investigate.
2505+
2506+
// Call the VJP using the original parameters.
2507+
SmallVector<SILValue, 8> newArgs;
2508+
auto vjpFnTy = vjpFn->getLoweredFunctionType();
2509+
auto numVJPParams = vjpFnTy->getNumParameters();
2510+
assert(vjpFnTy->getNumIndirectFormalResults() == 0 &&
2511+
"FIXME: handle vjp with indirect results");
2512+
newArgs.reserve(numVJPParams);
2513+
// Collect substituted arguments.
2514+
for (auto origArg : ai->getArguments())
2515+
newArgs.push_back(getOpValue(origArg));
2516+
assert(newArgs.size() == numVJPParams);
2517+
// Apply the VJP.
2518+
auto *vjpCall = builder.createApply(ai->getLoc(), vjp,
2519+
ai->getSubstitutionMap(), newArgs,
2520+
ai->isNonThrowing());
2521+
LLVM_DEBUG(getADDebugStream()
2522+
<< "Applied vjp function\n" << *vjpCall);
2523+
2524+
// Get the VJP results (original results and pullback).
2525+
SmallVector<SILValue, 8> vjpDirectResults;
2526+
extractAllElements(vjpCall, builder, vjpDirectResults);
2527+
ArrayRef<SILValue> originalDirectResults =
2528+
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
2529+
SILValue originalDirectResult = joinElements(originalDirectResults,
2530+
builder,
2531+
vjpCall->getLoc());
2532+
SILValue pullback = vjpDirectResults.back();
2533+
2534+
// Store the original result to the value map.
2535+
ValueMap.insert({ai, originalDirectResult});
2536+
2537+
// Checkpoint the original results.
2538+
getPrimalInfo().addStaticPrimalValueDecl(ai);
2539+
staticPrimalValues.push_back(originalDirectResult);
2540+
2541+
// Checkpoint the pullback.
2542+
getPrimalInfo().addPullbackDecl(ai, pullback->getType().getASTType());
2543+
staticPrimalValues.push_back(pullback);
2544+
2545+
// Some instructions that produce the callee may have been cloned.
2546+
// If the original callee did not have any users beyond this `apply`,
2547+
// recursively kill the cloned callee.
2548+
if (auto *origCallee = cast_or_null<SingleValueInstruction>(
2549+
ai->getCallee()->getDefiningInstruction()))
2550+
if (origCallee->hasOneUse())
2551+
recursivelyDeleteTriviallyDeadInstructions(
2552+
getOpValue(origCallee)->getDefiningInstruction());
2553+
}
2554+
23932555
/// Handle the primal transformation of an `apply` instruction. We do not
23942556
/// always transform `apply`. When we do, we do not just blindly differentiate
23952557
/// from all results w.r.t. all parameters. Instead, we let activity analysis
23962558
/// decide whether to transform and what differentiation indices to use.
2397-
void visitApplyInst(ApplyInst *ai) {
2559+
void visitApplyInstWithoutVJP(ApplyInst *ai) {
23982560
// Special handling logic only applies when `apply` is active. If not, just
23992561
// do standard cloning.
24002562
if (!activityInfo.isActive(ai, synthesis.indices)) {
@@ -3292,9 +3454,110 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
32923454
return rematCloner.getMappedValue(value);
32933455
}
32943456

3457+
void visitApplyInst(ApplyInst *ai) {
3458+
if (DifferentiationUseVJP)
3459+
visitApplyInstWithVJP(ai);
3460+
else
3461+
visitApplyInstWithoutVJP(ai);
3462+
}
3463+
3464+
void visitApplyInstWithVJP(ApplyInst *ai) {
3465+
// Replace a call to a function with a call to its pullback.
3466+
3467+
auto &builder = getBuilder();
3468+
auto loc = remapLocation(ai->getLoc());
3469+
3470+
// Look for the task that differentiates the callee.
3471+
auto &assocTasks = getDifferentiationTask()->getAssociatedTasks();
3472+
auto assocTaskLookUp = assocTasks.find(ai);
3473+
// If no task was found, then this task doesn't need to be differentiated.
3474+
if (assocTaskLookUp == assocTasks.end()) {
3475+
// Must not be active.
3476+
assert(
3477+
!activityInfo.isActive(ai, getDifferentiationTask()->getIndices()));
3478+
return;
3479+
}
3480+
auto *otherTask = assocTaskLookUp->getSecond();
3481+
auto origTy = otherTask->getOriginal()->getLoweredFunctionType();
3482+
SILFunctionConventions origConvs(origTy, getModule());
3483+
3484+
// Get the pullback.
3485+
auto *field = getPrimalInfo().lookUpPullbackDecl(ai);
3486+
assert(field);
3487+
SILValue pullback = builder.createStructExtract(remapLocation(ai->getLoc()),
3488+
primalValueAggregateInAdj,
3489+
field);
3490+
3491+
// Construct the pullback arguments.
3492+
SmallVector<SILValue, 8> args;
3493+
auto seed = getAdjointValue(ai);
3494+
auto *seedBuf = getBuilder().createAllocStack(loc, seed.getType());
3495+
materializeAdjointIndirect(seed, seedBuf);
3496+
if (seed.getType().isAddressOnly(getModule()))
3497+
args.push_back(seedBuf);
3498+
else {
3499+
auto access = getBuilder().createBeginAccess(
3500+
loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
3501+
/*noNestedConflict*/ true,
3502+
/*fromBuiltin*/ false);
3503+
args.push_back(getBuilder().createLoad(
3504+
loc, access, getBufferLOQ(seed.getSwiftType(), getAdjoint())));
3505+
getBuilder().createEndAccess(loc, access, /*aborted*/ false);
3506+
}
3507+
3508+
// Call the pullback.
3509+
auto *pullbackCall = builder.createApply(ai->getLoc(), pullback,
3510+
SubstitutionMap(), args,
3511+
/*isNonThrowing*/ false);
3512+
3513+
// Clean up seed allocation.
3514+
getBuilder().createDeallocStack(loc, seedBuf);
3515+
3516+
// If `pullbackCall` is a tuple, extract all results.
3517+
SmallVector<SILValue, 8> dirResults;
3518+
extractAllElements(pullbackCall, builder, dirResults);
3519+
// Get all results in type-defined order.
3520+
SmallVector<SILValue, 8> allResults;
3521+
collectAllActualResultsInTypeOrder(
3522+
pullbackCall, dirResults, pullbackCall->getIndirectSILResults(),
3523+
allResults);
3524+
LLVM_DEBUG({
3525+
auto &s = getADDebugStream();
3526+
s << "All direct results of the nested pullback call: \n";
3527+
llvm::for_each(dirResults, [&](SILValue v) { s << v; });
3528+
s << "All indirect results of the nested pullback call: \n";
3529+
llvm::for_each(pullbackCall->getIndirectSILResults(),
3530+
[&](SILValue v) { s << v; });
3531+
s << "All results of the nested pullback call: \n";
3532+
llvm::for_each(allResults, [&](SILValue v) { s << v; });
3533+
});
3534+
3535+
// Set adjoints for all original parameters.
3536+
auto originalParams = ai->getArgumentsWithoutIndirectResults();
3537+
auto origNumIndRes = origConvs.getNumIndirectSILResults();
3538+
auto allResultsIt = allResults.begin();
3539+
// If the applied adjoint returns the adjoint of the original self
3540+
// parameter, then it returns it first. Set the adjoint of the original
3541+
// self parameter.
3542+
auto selfParamIndex = originalParams.size() - 1;
3543+
if (ai->hasSelfArgument() &&
3544+
otherTask->getIndices().isWrtParameter(selfParamIndex))
3545+
addAdjointValue(ai->getArgument(origNumIndRes + selfParamIndex),
3546+
AdjointValue::getMaterialized(*allResultsIt++));
3547+
// Set adjoints for the remaining non-self original parameters.
3548+
for (unsigned i : otherTask->getIndices().parameters.set_bits()) {
3549+
// Do not set the adjoint of the original self parameter because we
3550+
// already added it at the beginning.
3551+
if (ai->hasSelfArgument() && i == selfParamIndex)
3552+
continue;
3553+
addAdjointValue(ai->getArgument(origNumIndRes + i),
3554+
AdjointValue::getMaterialized(*allResultsIt++));
3555+
}
3556+
}
3557+
32953558
/// Handle `apply` instruction. If it's active (on the differentiation path),
32963559
/// we replace it with its corresponding adjoint.
3297-
void visitApplyInst(ApplyInst *ai) {
3560+
void visitApplyInstWithoutVJP(ApplyInst *ai) {
32983561
// Replace a call to the function with a call to its adjoint.
32993562
auto &assocTasks = getDifferentiationTask()->getAssociatedTasks();
33003563
auto assocTaskLookUp = assocTasks.find(ai);

test/AutoDiff/autodiff_e2e_basic.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
2+
// RUN: %target-swift-frontend -Xllvm -differentiation-use-vjp -emit-sil %s | %FileCheck %s
23

34
@differentiable(reverse, adjoint: adjointId)
45
func id(_ x: Float) -> Float {

test/AutoDiff/method.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: %target-run-simple-swift
2+
// RUN: %target-run-use-vjp-swift
23
// REQUIRES: executable_test
34

45
import StdlibUnittest

0 commit comments

Comments
 (0)