33
33
#include " swift/AST/Module.h"
34
34
#include " swift/AST/ParameterList.h"
35
35
#include " swift/AST/SubstitutionMap.h"
36
- #include " swift/Serialization/SerializedSILLoader.h"
37
36
#include " swift/SIL/FormalLinkage.h"
38
37
#include " swift/SIL/LoopInfo.h"
39
38
#include " swift/SIL/SILBuilder.h"
@@ -856,10 +855,14 @@ class ADContext {
856
855
void clearTask (DifferentiationTask *task) {
857
856
LLVM_DEBUG (getADDebugStream () << " Clearing differentiation task for "
858
857
<< task->original ->getName () << ' \n ' );
859
- transform.notifyWillDeleteFunction (task->primal );
860
- module .eraseFunction (task->primal );
861
- transform.notifyWillDeleteFunction (task->adjoint );
862
- module .eraseFunction (task->adjoint );
858
+ if (task->primal ) {
859
+ transform.notifyWillDeleteFunction (task->primal );
860
+ module .eraseFunction (task->primal );
861
+ }
862
+ if (task->adjoint ) {
863
+ transform.notifyWillDeleteFunction (task->adjoint );
864
+ module .eraseFunction (task->adjoint );
865
+ }
863
866
transform.notifyWillDeleteFunction (task->jvp );
864
867
module .eraseFunction (task->jvp );
865
868
transform.notifyWillDeleteFunction (task->vjp );
@@ -980,6 +983,14 @@ class ADContext {
980
983
return differentiationTasks.back ().get ();
981
984
}
982
985
986
+ // / Declare an external reference to an associated function of `original`,
987
+ // / given a `[differentiable]` attribute of `original` and the associated
988
+ // / function kind.
989
+ SILFunction *
990
+ declareExternalAssociatedFunction (SILFunction *original,
991
+ SILDifferentiableAttr *attr,
992
+ AutoDiffAssociatedFunctionKind kind);
993
+
983
994
template <typename ... T, typename ... U>
984
995
InFlightDiagnostic diagnose (SourceLoc loc, Diag<T...> diag,
985
996
U &&... args) const {
@@ -1006,11 +1017,7 @@ class ADContext {
1006
1017
1007
1018
ADContext::ADContext (SILModuleTransform &transform)
1008
1019
: transform(transform), module(*transform.getModule()),
1009
- passManager(*transform.getPassManager()) {
1010
- // Note: `getSILLoader` performs important initialization and is necessary to
1011
- // prevent test failures related to `lookUpFunctionInWitnessTable`.
1012
- (void )module .getSILLoader ();
1013
- }
1020
+ passManager(*transform.getPassManager()) {}
1014
1021
1015
1022
void ADContext::emitNondifferentiabilityError (SILValue value,
1016
1023
const DifferentiationTask *task,
@@ -2261,7 +2268,7 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2261
2268
} // end anonymous namespace
2262
2269
2263
2270
bool PrimalGen::performSynthesis (FunctionSynthesisItem item) {
2264
- LLVM_DEBUG (getADDebugStream () << " Performing primal synthesis for original"
2271
+ LLVM_DEBUG (getADDebugStream () << " Performing primal synthesis for original "
2265
2272
<< item.original ->getName () << " and its corresponding primal "
2266
2273
<< item.target ->getName () << ' \n ' );
2267
2274
// FIXME: If the original function has multiple basic blocks, bail out since
@@ -2314,8 +2321,8 @@ bool PrimalGen::run() {
2314
2321
auto synthesis = worklist.back ();
2315
2322
worklist.pop_back ();
2316
2323
if (performSynthesis (synthesis)) {
2317
- context.clearTask (synthesis.task );
2318
2324
errorOccurred = true ;
2325
+ continue ;
2319
2326
}
2320
2327
synthesis.task ->getPrimalInfo ()->computePrimalValueStructType ();
2321
2328
synthesis.task ->setPrimalSynthesisState (FunctionSynthesisState::Done);
@@ -2373,8 +2380,8 @@ bool AdjointGen::run() {
2373
2380
auto synthesis = worklist.back ();
2374
2381
worklist.pop_back ();
2375
2382
if (performSynthesis (synthesis)) {
2376
- context.clearTask (synthesis.task );
2377
2383
errorOccurred = true ;
2384
+ continue ;
2378
2385
}
2379
2386
synthesis.task ->setAdjointSynthesisState (FunctionSynthesisState::Done);
2380
2387
}
@@ -3301,8 +3308,6 @@ void AdjointEmitter::materializeZeroIndirect(CanType type,
3301
3308
// %wm = witness_method ...
3302
3309
auto *getter = builder.createWitnessMethod (loc, type, confRef,
3303
3310
accessorDeclRef, methodType);
3304
- // Ensure that the witness table is linked.
3305
- (void )getModule ().lookUpFunctionInWitnessTable (confRef, accessorDeclRef);
3306
3311
// %metatype = metatype $T
3307
3312
auto metatypeType = CanMetatypeType::get (type, MetatypeRepresentation::Thick);
3308
3313
auto metatype = builder.createMetatype (
@@ -3594,8 +3599,6 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
3594
3599
// %0 = witness_method @+
3595
3600
auto witnessMethod = builder.createWitnessMethod (loc, adjointASTTy,
3596
3601
confRef, declRef, silFnTy);
3597
- // Ensure the witness method is linked.
3598
- getModule ().lookUpFunctionInWitnessTable (confRef, declRef);
3599
3602
auto subMap =
3600
3603
SubstitutionMap::getProtocolSubstitutions (proto, adjointASTTy, confRef);
3601
3604
// %1 = metatype $T.Type
@@ -3623,7 +3626,7 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
3623
3626
}
3624
3627
3625
3628
bool AdjointGen::performSynthesis (FunctionSynthesisItem item) {
3626
- LLVM_DEBUG (getADDebugStream () << " Performing adjoint synthesis for original"
3629
+ LLVM_DEBUG (getADDebugStream () << " Performing adjoint synthesis for original "
3627
3630
<< item.original ->getName () << " and its corresponding adjoint "
3628
3631
<< item.target ->getName () << ' \n ' );
3629
3632
auto &passManager = context.getPassManager ();
@@ -3639,25 +3642,90 @@ bool AdjointGen::performSynthesis(FunctionSynthesisItem item) {
3639
3642
// DifferentiationTask
3640
3643
// ===----------------------------------------------------------------------===//
3641
3644
3645
+ // Return the expected generic signature for autodiff associated functions given
3646
+ // a SILDifferentiableAttr. The expected generic signature is built from the
3647
+ // original generic signature and the attribute's requirements.
3648
+ static CanGenericSignature
3649
+ getAutoDiffAssociatedFunctionGenericSignature (SILDifferentiableAttr *attr,
3650
+ SILFunction *original) {
3651
+ auto originalGenSig =
3652
+ original->getLoweredFunctionType ()->getGenericSignature ();
3653
+ if (!originalGenSig)
3654
+ return nullptr ;
3655
+ GenericSignatureBuilder builder (original->getASTContext ());
3656
+ // Add original generic signature.
3657
+ builder.addGenericSignature (originalGenSig);
3658
+ // Add where clause requirements.
3659
+ auto source =
3660
+ GenericSignatureBuilder::FloatingRequirementSource::forAbstract ();
3661
+ for (auto &req : attr->getRequirements ())
3662
+ builder.addRequirement (req, source, original->getModule ().getSwiftModule ());
3663
+ return std::move (builder)
3664
+ .computeGenericSignature (SourceLoc (), /* allowConcreteGenericParams=*/ true )
3665
+ ->getCanonicalSignature ();
3666
+ }
3667
+
3668
+ SILFunction *
3669
+ ADContext::declareExternalAssociatedFunction (
3670
+ SILFunction *original, SILDifferentiableAttr *attr,
3671
+ AutoDiffAssociatedFunctionKind kind) {
3672
+ auto &module = getModule ();
3673
+ auto &indices = attr->getIndices ();
3674
+ auto originalTy = original->getLoweredFunctionType ();
3675
+ auto originalLoc = original->getLocation ();
3676
+ StringRef name;
3677
+ switch (kind) {
3678
+ case AutoDiffAssociatedFunctionKind::JVP:
3679
+ name = attr->getJVPName ();
3680
+ break ;
3681
+ case AutoDiffAssociatedFunctionKind::VJP:
3682
+ name = attr->getVJPName ();
3683
+ break ;
3684
+ }
3685
+ auto assocGenSig =
3686
+ getAutoDiffAssociatedFunctionGenericSignature (attr, original);
3687
+ auto assocFnTy = originalTy->getAutoDiffAssociatedFunctionType (
3688
+ indices.parameters , indices.source , /* differentiationOrder*/ 1 , kind,
3689
+ module , LookUpConformanceInModule (module .getSwiftModule ()), assocGenSig);
3690
+ SILOptFunctionBuilder fb (getTransform ());
3691
+ // Create external function declaration.
3692
+ auto *assocFn =
3693
+ fb.createFunction (SILLinkage::PublicExternal, name, assocFnTy,
3694
+ /* GenericEnv*/ nullptr , originalLoc, original->isBare (),
3695
+ IsNotTransparent, original->isSerialized ());
3696
+ // NOTE: Setting debug scope is necessary to prevent crash in TFPartition.
3697
+ assocFn->setDebugScope (new (module ) SILDebugScope (originalLoc, assocFn));
3698
+ return assocFn;
3699
+ }
3700
+
3642
3701
DifferentiationTask::DifferentiationTask (ADContext &context,
3643
3702
SILFunction *original,
3644
3703
SILDifferentiableAttr *&&attr,
3645
3704
DifferentiationInvoker invoker)
3646
3705
: context(context), original(original), attr(attr), invoker(invoker) {
3706
+ auto &module = context.getModule ();
3647
3707
if (attr->hasJVP ()) {
3648
- jvp = lookUpOrLinkFunction (attr->getJVPName (), context.getModule ());
3649
- assert (jvp);
3708
+ // If attribute specifies JVP name, try to look up JVP in current module.
3709
+ // Otherwise, create an external reference.
3710
+ jvp = module .lookUpFunction (attr->getJVPName ());
3711
+ if (!jvp)
3712
+ jvp = context.declareExternalAssociatedFunction (
3713
+ original, attr, AutoDiffAssociatedFunctionKind::JVP);
3650
3714
}
3651
3715
if (attr->hasVJP ()) {
3652
- vjp = lookUpOrLinkFunction (attr->getVJPName (), context.getModule ());
3653
- assert (vjp);
3716
+ // If attribute specifies VJP name, try to look up VJP in current module.
3717
+ // Otherwise, create an external reference.
3718
+ vjp = module .lookUpFunction (attr->getVJPName ());
3719
+ if (!vjp)
3720
+ vjp = context.declareExternalAssociatedFunction (
3721
+ original, attr, AutoDiffAssociatedFunctionKind::VJP);
3654
3722
}
3655
3723
3656
3724
if (!jvp)
3657
3725
createJVP ();
3658
3726
3659
3727
if (vjp) {
3660
- // If we already have the vjp , then we don't need to synthesize anything .
3728
+ // If the VJP exists , then no synthesis is needed .
3661
3729
primalSynthesisState = FunctionSynthesisState::NotNeeded;
3662
3730
adjointSynthesisState = FunctionSynthesisState::NotNeeded;
3663
3731
return ;
@@ -3670,31 +3738,6 @@ DifferentiationTask::DifferentiationTask(ADContext &context,
3670
3738
createVJP ();
3671
3739
}
3672
3740
3673
- // Return the expected generic signature for autodiff associated functions given
3674
- // a SILDifferentiableAttr. The expected generic signature is built from the
3675
- // original generic signature and the attribute's requirements.
3676
- static GenericSignature *
3677
- getAutoDiffAssociatedFunctionGenericSignature (SILDifferentiableAttr *attr,
3678
- SILFunction *original) {
3679
- auto originalGenSig =
3680
- original->getLoweredFunctionType ()->getGenericSignature ();
3681
- if (!originalGenSig)
3682
- return nullptr ;
3683
- GenericSignatureBuilder builder (original->getASTContext ());
3684
- // Add original generic signature.
3685
- builder.addGenericSignature (originalGenSig);
3686
- // Add where clause requirements.
3687
- auto source =
3688
- GenericSignatureBuilder::FloatingRequirementSource::forAbstract ();
3689
- for (auto &req : attr->getRequirements ())
3690
- builder.addRequirement (req, source, original->getModule ().getSwiftModule ());
3691
- auto canGenericSig = std::move (builder)
3692
- .computeGenericSignature (
3693
- SourceLoc (), /* allowConcreteGenericParams=*/ true )
3694
- ->getCanonicalSignature ();
3695
- return canGenericSig;
3696
- }
3697
-
3698
3741
void DifferentiationTask::createEmptyPrimal () {
3699
3742
assert (primalSynthesisState == FunctionSynthesisState::Needed);
3700
3743
assert (!primalInfo);
@@ -3707,7 +3750,7 @@ void DifferentiationTask::createEmptyPrimal() {
3707
3750
.getIdentifier (" AD__" + original->getName ().str () +
3708
3751
" __primal_" + indices.mangle ())
3709
3752
.str ();
3710
- auto * primalGenericSig =
3753
+ auto primalGenericSig =
3711
3754
getAutoDiffAssociatedFunctionGenericSignature (attr, original);
3712
3755
StructDecl *primalValueStructDecl = context.createPrimalValueStruct (this );
3713
3756
primalInfo = std::unique_ptr<PrimalInfo>(
@@ -3846,7 +3889,7 @@ void DifferentiationTask::createEmptyAdjoint() {
3846
3889
.getIdentifier (" AD__" + original->getName ().str () +
3847
3890
" __adjoint_" + getIndices ().mangle ())
3848
3891
.str ();
3849
- auto * adjGenericSig =
3892
+ auto adjGenericSig =
3850
3893
getAutoDiffAssociatedFunctionGenericSignature (attr, original);
3851
3894
auto *adjGenericEnv = adjGenericSig
3852
3895
? adjGenericSig->createGenericEnvironment ()
@@ -3887,7 +3930,7 @@ void DifferentiationTask::createJVP() {
3887
3930
.getIdentifier (" AD__" + original->getName ().str () +
3888
3931
" __jvp_" + getIndices ().mangle ())
3889
3932
.str ();
3890
- auto * jvpGenericSig =
3933
+ auto jvpGenericSig =
3891
3934
getAutoDiffAssociatedFunctionGenericSignature (attr, original);
3892
3935
auto *jvpGenericEnv = jvpGenericSig
3893
3936
? jvpGenericSig->createGenericEnvironment ()
@@ -3946,7 +3989,7 @@ void DifferentiationTask::createVJP() {
3946
3989
.getIdentifier (" AD__" + original->getName ().str () +
3947
3990
" __vjp_" + getIndices ().mangle ())
3948
3991
.str ();
3949
- auto * vjpGenericSig =
3992
+ auto vjpGenericSig =
3950
3993
getAutoDiffAssociatedFunctionGenericSignature (attr, original);
3951
3994
auto *vjpGenericEnv = vjpGenericSig
3952
3995
? vjpGenericSig->createGenericEnvironment ()
@@ -4299,22 +4342,33 @@ void Differentiation::run() {
4299
4342
for (auto *adfi : autodiffInsts)
4300
4343
errorProcessingAutoDiffInsts |= processAutoDiffFunctionInst (adfi, context);
4301
4344
4345
+ auto cleanUp = [&]() {
4346
+ for (auto &task : context.getDifferentiationTasks ())
4347
+ context.clearTask (task.get ());
4348
+ };
4349
+
4302
4350
// Run primal generation for newly created differentiation tasks. If any error
4303
4351
// occurs, back out.
4304
4352
PrimalGen primalGen (context);
4305
- if (primalGen.run ())
4353
+ if (primalGen.run ()) {
4354
+ cleanUp ();
4306
4355
return ;
4356
+ }
4307
4357
4308
4358
// Run adjoint generation for differentiation tasks. If any error occurs, back
4309
4359
// out.
4310
4360
AdjointGen adjointGen (context);
4311
- if (adjointGen.run ())
4361
+ if (adjointGen.run ()) {
4362
+ cleanUp ();
4312
4363
return ;
4364
+ }
4313
4365
4314
4366
// If there was any error that occurred during `autodiff_function` instruction
4315
4367
// processing, back out.
4316
- if (errorProcessingAutoDiffInsts)
4368
+ if (errorProcessingAutoDiffInsts) {
4369
+ cleanUp ();
4317
4370
return ;
4371
+ }
4318
4372
4319
4373
LLVM_DEBUG (getADDebugStream () << " All differentiation finished\n " );
4320
4374
}
0 commit comments