Skip to content

Commit c69d291

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] Eliminate SIL linking from differentiation. (#21973)
* [AutoDiff] Eliminate SIL linking from differentiation. The differentiation pass no longer depends on SerializedSILLoader to link JVP/VJP functions or witness tables. Background: the differentiation pass guarantees that JVP/VJPs exist for `@differentiable` functions defined in a module. If the original function has public linkage, then JVP/VJPs also have public linkage. If JVP/VJPs cannot be found in the current module for a given original function, create an empty declaration to them with public external linkage. No deserialization is necessary. A similar technique has been used for witness methods: we emit `witness_method` instructions rather than loading witness tables. Defensive explicit loading of witness tables is removed. --- This enables differentiation of functions from other modules without requiring that JVP/VJPS be `@inlinable`. This is necessary for library development. Joint work with @rxwei. * Declare `declareExternalAssociatedFunction` helper function. Address comments by @rxwei. * Minor edits. Address comments by @rxwei and @slavapestov. * Clean up differentiation tasks when error occurs. Fix crasher.
1 parent 0d48295 commit c69d291

File tree

1 file changed

+109
-55
lines changed

1 file changed

+109
-55
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 109 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
#include "swift/AST/Module.h"
3434
#include "swift/AST/ParameterList.h"
3535
#include "swift/AST/SubstitutionMap.h"
36-
#include "swift/Serialization/SerializedSILLoader.h"
3736
#include "swift/SIL/FormalLinkage.h"
3837
#include "swift/SIL/LoopInfo.h"
3938
#include "swift/SIL/SILBuilder.h"
@@ -856,10 +855,14 @@ class ADContext {
856855
void clearTask(DifferentiationTask *task) {
857856
LLVM_DEBUG(getADDebugStream() << "Clearing differentiation task for "
858857
<< task->original->getName() << '\n');
859-
transform.notifyWillDeleteFunction(task->primal);
860-
module.eraseFunction(task->primal);
861-
transform.notifyWillDeleteFunction(task->adjoint);
862-
module.eraseFunction(task->adjoint);
858+
if (task->primal) {
859+
transform.notifyWillDeleteFunction(task->primal);
860+
module.eraseFunction(task->primal);
861+
}
862+
if (task->adjoint) {
863+
transform.notifyWillDeleteFunction(task->adjoint);
864+
module.eraseFunction(task->adjoint);
865+
}
863866
transform.notifyWillDeleteFunction(task->jvp);
864867
module.eraseFunction(task->jvp);
865868
transform.notifyWillDeleteFunction(task->vjp);
@@ -980,6 +983,14 @@ class ADContext {
980983
return differentiationTasks.back().get();
981984
}
982985

986+
/// Declare an external reference to an associated function of `original`,
987+
/// given a `[differentiable]` attribute of `original` and the associated
988+
/// function kind.
989+
SILFunction *
990+
declareExternalAssociatedFunction(SILFunction *original,
991+
SILDifferentiableAttr *attr,
992+
AutoDiffAssociatedFunctionKind kind);
993+
983994
template <typename... T, typename... U>
984995
InFlightDiagnostic diagnose(SourceLoc loc, Diag<T...> diag,
985996
U &&... args) const {
@@ -1006,11 +1017,7 @@ class ADContext {
10061017

10071018
ADContext::ADContext(SILModuleTransform &transform)
10081019
: transform(transform), module(*transform.getModule()),
1009-
passManager(*transform.getPassManager()) {
1010-
// Note: `getSILLoader` performs important initialization and is necessary to
1011-
// prevent test failures related to `lookUpFunctionInWitnessTable`.
1012-
(void)module.getSILLoader();
1013-
}
1020+
passManager(*transform.getPassManager()) {}
10141021

10151022
void ADContext::emitNondifferentiabilityError(SILValue value,
10161023
const DifferentiationTask *task,
@@ -2261,7 +2268,7 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22612268
} // end anonymous namespace
22622269

22632270
bool PrimalGen::performSynthesis(FunctionSynthesisItem item) {
2264-
LLVM_DEBUG(getADDebugStream() << "Performing primal synthesis for original"
2271+
LLVM_DEBUG(getADDebugStream() << "Performing primal synthesis for original "
22652272
<< item.original->getName() << " and its corresponding primal "
22662273
<< item.target->getName() << '\n');
22672274
// FIXME: If the original function has multiple basic blocks, bail out since
@@ -2314,8 +2321,8 @@ bool PrimalGen::run() {
23142321
auto synthesis = worklist.back();
23152322
worklist.pop_back();
23162323
if (performSynthesis(synthesis)) {
2317-
context.clearTask(synthesis.task);
23182324
errorOccurred = true;
2325+
continue;
23192326
}
23202327
synthesis.task->getPrimalInfo()->computePrimalValueStructType();
23212328
synthesis.task->setPrimalSynthesisState(FunctionSynthesisState::Done);
@@ -2373,8 +2380,8 @@ bool AdjointGen::run() {
23732380
auto synthesis = worklist.back();
23742381
worklist.pop_back();
23752382
if (performSynthesis(synthesis)) {
2376-
context.clearTask(synthesis.task);
23772383
errorOccurred = true;
2384+
continue;
23782385
}
23792386
synthesis.task->setAdjointSynthesisState(FunctionSynthesisState::Done);
23802387
}
@@ -3301,8 +3308,6 @@ void AdjointEmitter::materializeZeroIndirect(CanType type,
33013308
// %wm = witness_method ...
33023309
auto *getter = builder.createWitnessMethod(loc, type, confRef,
33033310
accessorDeclRef, methodType);
3304-
// Ensure that the witness table is linked.
3305-
(void)getModule().lookUpFunctionInWitnessTable(confRef, accessorDeclRef);
33063311
// %metatype = metatype $T
33073312
auto metatypeType = CanMetatypeType::get(type, MetatypeRepresentation::Thick);
33083313
auto metatype = builder.createMetatype(
@@ -3594,8 +3599,6 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
35943599
// %0 = witness_method @+
35953600
auto witnessMethod = builder.createWitnessMethod(loc, adjointASTTy,
35963601
confRef, declRef, silFnTy);
3597-
// Ensure the witness method is linked.
3598-
getModule().lookUpFunctionInWitnessTable(confRef, declRef);
35993602
auto subMap =
36003603
SubstitutionMap::getProtocolSubstitutions(proto, adjointASTTy, confRef);
36013604
// %1 = metatype $T.Type
@@ -3623,7 +3626,7 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
36233626
}
36243627

36253628
bool AdjointGen::performSynthesis(FunctionSynthesisItem item) {
3626-
LLVM_DEBUG(getADDebugStream() << "Performing adjoint synthesis for original"
3629+
LLVM_DEBUG(getADDebugStream() << "Performing adjoint synthesis for original "
36273630
<< item.original->getName() << " and its corresponding adjoint "
36283631
<< item.target->getName() << '\n');
36293632
auto &passManager = context.getPassManager();
@@ -3639,25 +3642,90 @@ bool AdjointGen::performSynthesis(FunctionSynthesisItem item) {
36393642
// DifferentiationTask
36403643
//===----------------------------------------------------------------------===//
36413644

3645+
// Return the expected generic signature for autodiff associated functions given
3646+
// a SILDifferentiableAttr. The expected generic signature is built from the
3647+
// original generic signature and the attribute's requirements.
3648+
static CanGenericSignature
3649+
getAutoDiffAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr,
3650+
SILFunction *original) {
3651+
auto originalGenSig =
3652+
original->getLoweredFunctionType()->getGenericSignature();
3653+
if (!originalGenSig)
3654+
return nullptr;
3655+
GenericSignatureBuilder builder(original->getASTContext());
3656+
// Add original generic signature.
3657+
builder.addGenericSignature(originalGenSig);
3658+
// Add where clause requirements.
3659+
auto source =
3660+
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
3661+
for (auto &req : attr->getRequirements())
3662+
builder.addRequirement(req, source, original->getModule().getSwiftModule());
3663+
return std::move(builder)
3664+
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams=*/true)
3665+
->getCanonicalSignature();
3666+
}
3667+
3668+
SILFunction *
3669+
ADContext::declareExternalAssociatedFunction(
3670+
SILFunction *original, SILDifferentiableAttr *attr,
3671+
AutoDiffAssociatedFunctionKind kind) {
3672+
auto &module = getModule();
3673+
auto &indices = attr->getIndices();
3674+
auto originalTy = original->getLoweredFunctionType();
3675+
auto originalLoc = original->getLocation();
3676+
StringRef name;
3677+
switch (kind) {
3678+
case AutoDiffAssociatedFunctionKind::JVP:
3679+
name = attr->getJVPName();
3680+
break;
3681+
case AutoDiffAssociatedFunctionKind::VJP:
3682+
name = attr->getVJPName();
3683+
break;
3684+
}
3685+
auto assocGenSig =
3686+
getAutoDiffAssociatedFunctionGenericSignature(attr, original);
3687+
auto assocFnTy = originalTy->getAutoDiffAssociatedFunctionType(
3688+
indices.parameters, indices.source, /*differentiationOrder*/ 1, kind,
3689+
module, LookUpConformanceInModule(module.getSwiftModule()), assocGenSig);
3690+
SILOptFunctionBuilder fb(getTransform());
3691+
// Create external function declaration.
3692+
auto *assocFn =
3693+
fb.createFunction(SILLinkage::PublicExternal, name, assocFnTy,
3694+
/*GenericEnv*/ nullptr, originalLoc, original->isBare(),
3695+
IsNotTransparent, original->isSerialized());
3696+
// NOTE: Setting debug scope is necessary to prevent crash in TFPartition.
3697+
assocFn->setDebugScope(new (module) SILDebugScope(originalLoc, assocFn));
3698+
return assocFn;
3699+
}
3700+
36423701
DifferentiationTask::DifferentiationTask(ADContext &context,
36433702
SILFunction *original,
36443703
SILDifferentiableAttr *&&attr,
36453704
DifferentiationInvoker invoker)
36463705
: context(context), original(original), attr(attr), invoker(invoker) {
3706+
auto &module = context.getModule();
36473707
if (attr->hasJVP()) {
3648-
jvp = lookUpOrLinkFunction(attr->getJVPName(), context.getModule());
3649-
assert(jvp);
3708+
// If attribute specifies JVP name, try to look up JVP in current module.
3709+
// Otherwise, create an external reference.
3710+
jvp = module.lookUpFunction(attr->getJVPName());
3711+
if (!jvp)
3712+
jvp = context.declareExternalAssociatedFunction(
3713+
original, attr, AutoDiffAssociatedFunctionKind::JVP);
36503714
}
36513715
if (attr->hasVJP()) {
3652-
vjp = lookUpOrLinkFunction(attr->getVJPName(), context.getModule());
3653-
assert(vjp);
3716+
// If attribute specifies VJP name, try to look up VJP in current module.
3717+
// Otherwise, create an external reference.
3718+
vjp = module.lookUpFunction(attr->getVJPName());
3719+
if (!vjp)
3720+
vjp = context.declareExternalAssociatedFunction(
3721+
original, attr, AutoDiffAssociatedFunctionKind::VJP);
36543722
}
36553723

36563724
if (!jvp)
36573725
createJVP();
36583726

36593727
if (vjp) {
3660-
// If we already have the vjp, then we don't need to synthesize anything.
3728+
// If the VJP exists, then no synthesis is needed.
36613729
primalSynthesisState = FunctionSynthesisState::NotNeeded;
36623730
adjointSynthesisState = FunctionSynthesisState::NotNeeded;
36633731
return;
@@ -3670,31 +3738,6 @@ DifferentiationTask::DifferentiationTask(ADContext &context,
36703738
createVJP();
36713739
}
36723740

3673-
// Return the expected generic signature for autodiff associated functions given
3674-
// a SILDifferentiableAttr. The expected generic signature is built from the
3675-
// original generic signature and the attribute's requirements.
3676-
static GenericSignature *
3677-
getAutoDiffAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr,
3678-
SILFunction *original) {
3679-
auto originalGenSig =
3680-
original->getLoweredFunctionType()->getGenericSignature();
3681-
if (!originalGenSig)
3682-
return nullptr;
3683-
GenericSignatureBuilder builder(original->getASTContext());
3684-
// Add original generic signature.
3685-
builder.addGenericSignature(originalGenSig);
3686-
// Add where clause requirements.
3687-
auto source =
3688-
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
3689-
for (auto &req : attr->getRequirements())
3690-
builder.addRequirement(req, source, original->getModule().getSwiftModule());
3691-
auto canGenericSig = std::move(builder)
3692-
.computeGenericSignature(
3693-
SourceLoc(), /*allowConcreteGenericParams=*/true)
3694-
->getCanonicalSignature();
3695-
return canGenericSig;
3696-
}
3697-
36983741
void DifferentiationTask::createEmptyPrimal() {
36993742
assert(primalSynthesisState == FunctionSynthesisState::Needed);
37003743
assert(!primalInfo);
@@ -3707,7 +3750,7 @@ void DifferentiationTask::createEmptyPrimal() {
37073750
.getIdentifier("AD__" + original->getName().str() +
37083751
"__primal_" + indices.mangle())
37093752
.str();
3710-
auto *primalGenericSig =
3753+
auto primalGenericSig =
37113754
getAutoDiffAssociatedFunctionGenericSignature(attr, original);
37123755
StructDecl *primalValueStructDecl = context.createPrimalValueStruct(this);
37133756
primalInfo = std::unique_ptr<PrimalInfo>(
@@ -3846,7 +3889,7 @@ void DifferentiationTask::createEmptyAdjoint() {
38463889
.getIdentifier("AD__" + original->getName().str() +
38473890
"__adjoint_" + getIndices().mangle())
38483891
.str();
3849-
auto *adjGenericSig =
3892+
auto adjGenericSig =
38503893
getAutoDiffAssociatedFunctionGenericSignature(attr, original);
38513894
auto *adjGenericEnv = adjGenericSig
38523895
? adjGenericSig->createGenericEnvironment()
@@ -3887,7 +3930,7 @@ void DifferentiationTask::createJVP() {
38873930
.getIdentifier("AD__" + original->getName().str() +
38883931
"__jvp_" + getIndices().mangle())
38893932
.str();
3890-
auto *jvpGenericSig =
3933+
auto jvpGenericSig =
38913934
getAutoDiffAssociatedFunctionGenericSignature(attr, original);
38923935
auto *jvpGenericEnv = jvpGenericSig
38933936
? jvpGenericSig->createGenericEnvironment()
@@ -3946,7 +3989,7 @@ void DifferentiationTask::createVJP() {
39463989
.getIdentifier("AD__" + original->getName().str() +
39473990
"__vjp_" + getIndices().mangle())
39483991
.str();
3949-
auto *vjpGenericSig =
3992+
auto vjpGenericSig =
39503993
getAutoDiffAssociatedFunctionGenericSignature(attr, original);
39513994
auto *vjpGenericEnv = vjpGenericSig
39523995
? vjpGenericSig->createGenericEnvironment()
@@ -4299,22 +4342,33 @@ void Differentiation::run() {
42994342
for (auto *adfi : autodiffInsts)
43004343
errorProcessingAutoDiffInsts |= processAutoDiffFunctionInst(adfi, context);
43014344

4345+
auto cleanUp = [&]() {
4346+
for (auto &task : context.getDifferentiationTasks())
4347+
context.clearTask(task.get());
4348+
};
4349+
43024350
// Run primal generation for newly created differentiation tasks. If any error
43034351
// occurs, back out.
43044352
PrimalGen primalGen(context);
4305-
if (primalGen.run())
4353+
if (primalGen.run()) {
4354+
cleanUp();
43064355
return;
4356+
}
43074357

43084358
// Run adjoint generation for differentiation tasks. If any error occurs, back
43094359
// out.
43104360
AdjointGen adjointGen(context);
4311-
if (adjointGen.run())
4361+
if (adjointGen.run()) {
4362+
cleanUp();
43124363
return;
4364+
}
43134365

43144366
// If there was any error that occurred during `autodiff_function` instruction
43154367
// processing, back out.
4316-
if (errorProcessingAutoDiffInsts)
4368+
if (errorProcessingAutoDiffInsts) {
4369+
cleanUp();
43174370
return;
4371+
}
43184372

43194373
LLVM_DEBUG(getADDebugStream() << "All differentiation finished\n");
43204374
}

0 commit comments

Comments
 (0)