Skip to content

Commit a2ba998

Browse files
authored
[AutoDiff] Fix autodiff_function instruction worklist. (#24941)
Fix `autodiff_function` instruction worklist registration/invalidation. - In `VJPEmitter`, add cloned `autodiff_function` instructions in VJP to worklist. - If `autodiff_function` instructions are deleted in `foldAutoDiffFunctionExtraction`, also remove their worklist occurrences. Resolves TF-515.
1 parent 995e4a1 commit a2ba998

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3181,6 +3181,14 @@ class VJPEmitter final
31813181
recursivelyDeleteTriviallyDeadInstructions(
31823182
getOpValue(origCallee)->getDefiningInstruction());
31833183
}
3184+
3185+
void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
3186+
// Clone `autodiff_function` from original to VJP, then add the cloned
3187+
// instruction to the `autodiff_function` worklist.
3188+
SILClonerWithScopes::visitAutoDiffFunctionInst(adfi);
3189+
auto *newADFI = cast<AutoDiffFunctionInst>(getOpValue(adfi));
3190+
context.getAutoDiffFunctionInsts().push_back(newADFI);
3191+
}
31843192
};
31853193
} // end anonymous namespace
31863194

@@ -5948,7 +5956,8 @@ SILValue ADContext::promoteToDifferentiableFunction(
59485956
///
59495957
/// Folding can be disabled by the `SkipFoldingAutoDiffFunctionExtraction` flag
59505958
/// for SIL testing purposes.
5951-
static void foldAutoDiffFunctionExtraction(AutoDiffFunctionInst *source) {
5959+
static void foldAutoDiffFunctionExtraction(
5960+
ADContext &context, AutoDiffFunctionInst *source) {
59525961
// Iterate through all `autodiff_function` instruction uses.
59535962
for (auto use : source->getUses()) {
59545963
auto *adfei = dyn_cast<AutoDiffFunctionExtractInst>(use->getUser());
@@ -5970,8 +5979,14 @@ static void foldAutoDiffFunctionExtraction(AutoDiffFunctionInst *source) {
59705979
adfei->eraseFromParent();
59715980
}
59725981
// If the `autodiff_function` instruction has no remaining uses, erase it.
5973-
if (isInstructionTriviallyDead(source))
5974-
source->eraseFromParent();
5982+
if (!isInstructionTriviallyDead(source))
5983+
return;
5984+
source->eraseFromParent();
5985+
// Delete all worklist occurrences of `source` by setting them to nullptr.
5986+
// This is more efficient than APIs like `llvm::erase_if`.
5987+
for (auto &inst : context.getAutoDiffFunctionInsts())
5988+
if (inst == source)
5989+
inst = nullptr;
59755990
}
59765991

59775992
bool ADContext::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
@@ -6000,11 +6015,10 @@ bool ADContext::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
60006015
adfi->eraseFromParent();
60016016
// If the promoted `@differentiable` function-typed value is an
60026017
// `autodiff_function` instruction, fold `autodiff_function_extract`
6003-
// instructions.
6004-
// If `autodiff_function_extract` folding is disabled, return.
6018+
// instructions. If `autodiff_function_extract` folding is disabled, return.
60056019
if (!SkipFoldingAutoDiffFunctionExtraction)
60066020
if (auto *newADFI = dyn_cast<AutoDiffFunctionInst>(differentiableFnValue))
6007-
foldAutoDiffFunctionExtraction(newADFI);
6021+
foldAutoDiffFunctionExtraction(*this, newADFI);
60086022
transform.invalidateAnalysis(
60096023
parent, SILAnalysis::InvalidationKind::FunctionBody);
60106024
return false;

0 commit comments

Comments
 (0)