Skip to content

Commit b1009c5

Browse files
committed
Speculate
1 parent 75fee65 commit b1009c5

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ cl::opt<bool>
141141
EnzymeNameInstructions("enzyme-name-instructions", cl::init(false),
142142
cl::Hidden,
143143
cl::desc("Have enzyme name all instructions"));
144+
145+
cl::opt<bool> EnzymeSelectOpt("enzyme-select-opt", cl::init(true), cl::Hidden,
146+
cl::desc("Run Enzyme select optimization"));
144147
}
145148

146149
/// Is the use of value val as an argument of call CI potentially captured
@@ -1648,17 +1651,20 @@ void PreProcessCache::optimizeIntermediate(Function *F) {
16481651
PromotePass().run(*F, FAM);
16491652
GVN().run(*F, FAM);
16501653
SROA().run(*F, FAM);
1654+
1655+
if (EnzymeSelectOpt) {
16511656
#if LLVM_VERSION_MAJOR >= 12
1652-
SimplifyCFGOptions scfgo;
1657+
SimplifyCFGOptions scfgo;
16531658
#else
1654-
SimplifyCFGOptions scfgo(
1655-
/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false,
1656-
/*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true,
1657-
/*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
1659+
SimplifyCFGOptions scfgo(
1660+
/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false,
1661+
/*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true,
1662+
/*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
16581663
#endif
1659-
SimplifyCFGPass(scfgo).run(*F, FAM);
1660-
CorrelatedValuePropagationPass().run(*F, FAM);
1661-
SelectOptimization(F);
1664+
SimplifyCFGPass(scfgo).run(*F, FAM);
1665+
CorrelatedValuePropagationPass().run(*F, FAM);
1666+
SelectOptimization(F);
1667+
}
16621668
// EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM);
16631669

16641670
for (Function &Impl : *F->getParent()) {

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ llvm::cl::opt<bool>
7575
llvm::cl::opt<bool>
7676
EnzymeRegisterReduce("enzyme-register-reduce", cl::init(false), cl::Hidden,
7777
cl::desc("Reduce the amount of register reduce"));
78+
llvm::cl::opt<bool>
79+
EnzymeSpeculatePHIs("enzyme-speculate-phis", cl::init(false), cl::Hidden,
80+
cl::desc("Speculatively execute phi computations"));
7881
}
7982

8083
bool isPotentialLastLoopValue(Value *val, const BasicBlock *loc,
@@ -869,12 +872,20 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
869872

870873
unwrap_cache[blocks[i]] = unwrap_cache[oldB];
871874
lookup_cache[blocks[i]] = lookup_cache[oldB];
875+
auto PB = *done[std::make_pair(parent, predBlocks[i])].begin();
876+
877+
if (auto inst = dyn_cast<Instruction>(
878+
phi->getIncomingValueForBlock(PB))) {
879+
if (inst->mayReadFromMemory() || !EnzymeSpeculatePHIs)
880+
vals.push_back(
881+
getOpFull(B, phi->getIncomingValueForBlock(PB), PB));
882+
else
883+
vals.push_back(getOpFull(
884+
BuilderM, phi->getIncomingValueForBlock(PB), PB));
885+
} else
886+
vals.push_back(
887+
getOpFull(BuilderM, phi->getIncomingValueForBlock(PB), PB));
872888

873-
vals.push_back(getOpFull(
874-
B,
875-
phi->getIncomingValueForBlock(
876-
*done[std::make_pair(parent, predBlocks[i])].begin()),
877-
*done[std::make_pair(parent, predBlocks[i])].begin()));
878889
if (!vals[i]) {
879890
for (size_t j = 0; j < i; i++) {
880891
reverseBlocks[fwd].erase(std::find(reverseBlocks[fwd].begin(),
@@ -1034,7 +1045,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
10341045

10351046
if (auto inst =
10361047
dyn_cast<Instruction>(phi->getIncomingValueForBlock(PB))) {
1037-
if (inst->mayReadFromMemory())
1048+
if (inst->mayReadFromMemory() || !EnzymeSpeculatePHIs)
10381049
vals.push_back(getOpFull(B, phi->getIncomingValueForBlock(PB), PB));
10391050
else
10401051
vals.push_back(

0 commit comments

Comments
 (0)