@@ -3181,6 +3181,14 @@ class VJPEmitter final
3181
3181
recursivelyDeleteTriviallyDeadInstructions (
3182
3182
getOpValue (origCallee)->getDefiningInstruction ());
3183
3183
}
3184
+
3185
+ void visitAutoDiffFunctionInst (AutoDiffFunctionInst *adfi) {
3186
+ // Clone `autodiff_function` from original to VJP, then add the cloned
3187
+ // instruction to the `autodiff_function` worklist.
3188
+ SILClonerWithScopes::visitAutoDiffFunctionInst (adfi);
3189
+ auto *newADFI = cast<AutoDiffFunctionInst>(getOpValue (adfi));
3190
+ context.getAutoDiffFunctionInsts ().push_back (newADFI);
3191
+ }
3184
3192
};
3185
3193
} // end anonymous namespace
3186
3194
@@ -5948,7 +5956,8 @@ SILValue ADContext::promoteToDifferentiableFunction(
5948
5956
// /
5949
5957
// / Folding can be disabled by the `SkipFoldingAutoDiffFunctionExtraction` flag
5950
5958
// / for SIL testing purposes.
5951
- static void foldAutoDiffFunctionExtraction (AutoDiffFunctionInst *source) {
5959
+ static void foldAutoDiffFunctionExtraction (
5960
+ ADContext &context, AutoDiffFunctionInst *source) {
5952
5961
// Iterate through all `autodiff_function` instruction uses.
5953
5962
for (auto use : source->getUses ()) {
5954
5963
auto *adfei = dyn_cast<AutoDiffFunctionExtractInst>(use->getUser ());
@@ -5970,8 +5979,14 @@ static void foldAutoDiffFunctionExtraction(AutoDiffFunctionInst *source) {
5970
5979
adfei->eraseFromParent ();
5971
5980
}
5972
5981
// If the `autodiff_function` instruction has no remaining uses, erase it.
5973
- if (isInstructionTriviallyDead (source))
5974
- source->eraseFromParent ();
5982
+ if (!isInstructionTriviallyDead (source))
5983
+ return ;
5984
+ source->eraseFromParent ();
5985
+ // Delete all worklist occurrences of `source` by setting them to nullptr.
5986
+ // This is more efficient than APIs like `llvm::erase_if`.
5987
+ for (auto &inst : context.getAutoDiffFunctionInsts ())
5988
+ if (inst == source)
5989
+ inst = nullptr ;
5975
5990
}
5976
5991
5977
5992
bool ADContext::processAutoDiffFunctionInst (AutoDiffFunctionInst *adfi) {
@@ -6000,11 +6015,10 @@ bool ADContext::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
6000
6015
adfi->eraseFromParent ();
6001
6016
// If the promoted `@differentiable` function-typed value is an
6002
6017
// `autodiff_function` instruction, fold `autodiff_function_extract`
6003
- // instructions.
6004
- // If `autodiff_function_extract` folding is disabled, return.
6018
+ // instructions. If `autodiff_function_extract` folding is disabled, return.
6005
6019
if (!SkipFoldingAutoDiffFunctionExtraction)
6006
6020
if (auto *newADFI = dyn_cast<AutoDiffFunctionInst>(differentiableFnValue))
6007
- foldAutoDiffFunctionExtraction (newADFI);
6021
+ foldAutoDiffFunctionExtraction (* this , newADFI);
6008
6022
transform.invalidateAnalysis (
6009
6023
parent, SILAnalysis::InvalidationKind::FunctionBody);
6010
6024
return false ;
0 commit comments