@@ -56,6 +56,12 @@ using llvm::SmallDenseMap;
56
56
using llvm::SmallDenseSet;
57
57
using llvm::SmallSet;
58
58
59
+ // This flag is used to disable `autodiff_function_extract` instruction folding
60
+ // for SIL testing purposes.
61
+ static llvm::cl::opt<bool > SkipFoldingAutoDiffFunctionExtraction (
62
+ " differentiation-skip-folding-autodiff-function-extraction" ,
63
+ llvm::cl::init (false ));
64
+
59
65
// ===----------------------------------------------------------------------===//
60
66
// Helpers
61
67
// ===----------------------------------------------------------------------===//
@@ -5965,6 +5971,44 @@ SILValue ADContext::promoteToDifferentiableFunction(
5965
5971
return adfi;
5966
5972
}
5967
5973
5974
+ // / Fold `autodiff_function_extract` users of the given `autodiff_function`
5975
+ // / instruction, directly replacing them with `autodiff_function` instruction
5976
+ // / operands. If the `autodiff_function` instruction has no
5977
+ // / non-`autodiff_function_extract` users, delete the instruction itself after
5978
+ // / folding.
5979
+ // /
5980
+ // / Folding can be disabled by the `SkipFoldingAutoDiffFunctionExtraction` flag
5981
+ // / for SIL testing purposes.
5982
+ static void foldAutoDiffFunctionExtraction (AutoDiffFunctionInst *source) {
5983
+ bool hasOnlyAutoDiffFunctionExtractUsers = true ;
5984
+ // Iterate through all `autodiff_function` instruction uses.
5985
+ for (auto use : source->getUses ()) {
5986
+ auto *adfei = dyn_cast<AutoDiffFunctionExtractInst>(use->getUser ());
5987
+ // If user is not an `autodiff_function_extract` instruction, set flag to
5988
+ // false.
5989
+ if (!adfei) {
5990
+ hasOnlyAutoDiffFunctionExtractUsers = false ;
5991
+ continue ;
5992
+ }
5993
+ // Fold original function extractors.
5994
+ if (adfei->getExtractee () == AutoDiffFunctionExtractee::Original) {
5995
+ auto originalFnValue = source->getOriginalFunction ();
5996
+ adfei->replaceAllUsesWith (originalFnValue);
5997
+ adfei->eraseFromParent ();
5998
+ continue ;
5999
+ }
6000
+ // Fold associated function extractors.
6001
+ auto assocFnValue = source->getAssociatedFunction (
6002
+ adfei->getDifferentiationOrder (), adfei->getAssociatedFunctionKind ());
6003
+ adfei->replaceAllUsesWith (assocFnValue);
6004
+ adfei->eraseFromParent ();
6005
+ }
6006
+ // If all users are `autodiff_function_extract` instructions, erase the
6007
+ // `autodiff_function` instruction itself.
6008
+ if (hasOnlyAutoDiffFunctionExtractUsers)
6009
+ source->eraseFromParent ();
6010
+ }
6011
+
5968
6012
bool ADContext::processAutoDiffFunctionInst (AutoDiffFunctionInst *adfi) {
5969
6013
if (adfi->getNumAssociatedFunctions () ==
5970
6014
autodiff::getNumAutoDiffAssociatedFunctions (
@@ -5995,6 +6039,13 @@ bool ADContext::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
5995
6039
// Replace all uses of `adfi`.
5996
6040
adfi->replaceAllUsesWith (differentiableFnValue);
5997
6041
adfi->eraseFromParent ();
6042
+ // If the promoted `@differentiable` function-typed value is an
6043
+ // `autodiff_function` instruction, fold `autodiff_function_extract`
6044
+ // instructions.
6045
+ // If `autodiff_function_extract` folding is disabled, return.
6046
+ if (!SkipFoldingAutoDiffFunctionExtraction)
6047
+ if (auto *newADFI = dyn_cast<AutoDiffFunctionInst>(differentiableFnValue))
6048
+ foldAutoDiffFunctionExtraction (newADFI);
5998
6049
transform.invalidateAnalysis (
5999
6050
parent, SILAnalysis::InvalidationKind::FunctionBody);
6000
6051
return false ;
0 commit comments