Skip to content

Commit 2b43542

Browse files
committed
Add autodiff_function_extract folding optimization.
Fold `autodiff_function_extract` users of `autodiff_function` instructions, directly replacing them with operands of the `autodiff_function` instruction. If the `autodiff_function` instruction has no non-`autodiff_function_extract` users, delete the instruction itself after folding. The `differentiation-skip-folding-autodiff-function-extraction` flag disables folding for SIL testing purposes.
1 parent 646096e commit 2b43542

File tree

4 files changed

+54
-3
lines changed

4 files changed

+54
-3
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ using llvm::SmallDenseMap;
5656
using llvm::SmallDenseSet;
5757
using llvm::SmallSet;
5858

59+
// This flag is used to disable `autodiff_function_extract` instruction folding
60+
// for SIL testing purposes.
61+
static llvm::cl::opt<bool> SkipFoldingAutoDiffFunctionExtraction(
62+
"differentiation-skip-folding-autodiff-function-extraction",
63+
llvm::cl::init(false));
64+
5965
//===----------------------------------------------------------------------===//
6066
// Helpers
6167
//===----------------------------------------------------------------------===//
@@ -5965,6 +5971,44 @@ SILValue ADContext::promoteToDifferentiableFunction(
59655971
return adfi;
59665972
}
59675973

5974+
/// Fold `autodiff_function_extract` users of the given `autodiff_function`
5975+
/// instruction, directly replacing them with `autodiff_function` instruction
5976+
/// operands. If the `autodiff_function` instruction has no
5977+
/// non-`autodiff_function_extract` users, delete the instruction itself after
5978+
/// folding.
5979+
///
5980+
/// Folding can be disabled by the `SkipFoldingAutoDiffFunctionExtraction` flag
5981+
/// for SIL testing purposes.
5982+
static void foldAutoDiffFunctionExtraction(AutoDiffFunctionInst *source) {
5983+
bool hasOnlyAutoDiffFunctionExtractUsers = true;
5984+
// Iterate through all `autodiff_function` instruction uses.
5985+
for (auto use : source->getUses()) {
5986+
auto *adfei = dyn_cast<AutoDiffFunctionExtractInst>(use->getUser());
5987+
// If user is not an `autodiff_function_extract` instruction, set flag to
5988+
// false.
5989+
if (!adfei) {
5990+
hasOnlyAutoDiffFunctionExtractUsers = false;
5991+
continue;
5992+
}
5993+
// Fold original function extractors.
5994+
if (adfei->getExtractee() == AutoDiffFunctionExtractee::Original) {
5995+
auto originalFnValue = source->getOriginalFunction();
5996+
adfei->replaceAllUsesWith(originalFnValue);
5997+
adfei->eraseFromParent();
5998+
continue;
5999+
}
6000+
// Fold associated function extractors.
6001+
auto assocFnValue = source->getAssociatedFunction(
6002+
adfei->getDifferentiationOrder(), adfei->getAssociatedFunctionKind());
6003+
adfei->replaceAllUsesWith(assocFnValue);
6004+
adfei->eraseFromParent();
6005+
}
6006+
// If all users are `autodiff_function_extract` instructions, erase the
6007+
// `autodiff_function` instruction itself.
6008+
if (hasOnlyAutoDiffFunctionExtractUsers)
6009+
source->eraseFromParent();
6010+
}
6011+
59686012
bool ADContext::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
59696013
if (adfi->getNumAssociatedFunctions() ==
59706014
autodiff::getNumAutoDiffAssociatedFunctions(
@@ -5995,6 +6039,13 @@ bool ADContext::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
59956039
// Replace all uses of `adfi`.
59966040
adfi->replaceAllUsesWith(differentiableFnValue);
59976041
adfi->eraseFromParent();
6042+
// If the promoted `@differentiable` function-typed value is an
6043+
// `autodiff_function` instruction, fold `autodiff_function_extract`
6044+
// instructions.
6045+
// If `autodiff_function_extract` folding is disabled, return.
6046+
if (!SkipFoldingAutoDiffFunctionExtraction)
6047+
if (auto *newADFI = dyn_cast<AutoDiffFunctionInst>(differentiableFnValue))
6048+
foldAutoDiffFunctionExtraction(newADFI);
59986049
transform.invalidateAnalysis(
59996050
parent, SILAnalysis::InvalidationKind::FunctionBody);
60006051
return false;

test/AutoDiff/refcounting.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -Xllvm -differentiation-skip-folding-autodiff-function-extraction %s | %FileCheck %s
22

33
public class NonTrivialStuff : Equatable {
44
public init() {}

test/AutoDiff/witness_method_autodiff.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-sil-opt -differentiation %s | %FileCheck %s
1+
// RUN: %target-sil-opt -differentiation -differentiation-skip-folding-autodiff-function-extraction %s | %FileCheck %s
22

33
sil_stage raw
44

test/AutoDiff/witness_table_silgen.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-skip-folding-autodiff-function-extraction %s | %FileCheck %s
22

33
protocol Proto : Differentiable {
44
@differentiable(wrt: (x, y))

0 commit comments

Comments
 (0)