@@ -75,6 +75,9 @@ llvm::cl::opt<bool>
75
75
// Boolean command-line option, hidden from the default help listing
// (cl::Hidden) and defaulting to false (cl::init(false)). The code that reads
// this flag is not part of this excerpt.
// NOTE(review): the bare numbers interleaved below and the space just inside
// each opening quote look like diff-viewer extraction residue rather than
// source text -- confirm against the real file before relying on the exact
// option name.
llvm::cl::opt<bool >
76
76
EnzymeRegisterReduce (" enzyme-register-reduce" , cl::init(false ), cl::Hidden,
77
77
cl::desc(" Reduce the amount of register reduce" ));
78
// Boolean command-line option added by this change; hidden (cl::Hidden) and
// defaulting to false (cl::init(false)).
// In the visible GradientUtils::unwrapM hunk it is tested as
// `inst->mayReadFromMemory() || !EnzymeSpeculatePHIs`: while the flag is false,
// every Instruction incoming value of the phi is passed to getOpFull with
// builder `B`; once it is true, only instructions for which
// mayReadFromMemory() returns true use `B`, and the rest use `BuilderM`.
// NOTE(review): the leading `+` markers and bare numbers are diff-viewer
// residue, as are (presumably) the spaces just inside the opening quotes --
// confirm against the real file.
+ llvm::cl::opt<bool >
79
+ EnzymeSpeculatePHIs (" enzyme-speculate-phis" , cl::init(false ), cl::Hidden,
80
+ cl::desc(" Speculatively execute phi computations" ));
78
81
}
79
82
80
83
bool isPotentialLastLoopValue (Value *val, const BasicBlock *loc,
@@ -869,12 +872,20 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
869
872
870
873
unwrap_cache[blocks[i]] = unwrap_cache[oldB];
871
874
lookup_cache[blocks[i]] = lookup_cache[oldB];
875
+ auto PB = *done[std::make_pair (parent, predBlocks[i])].begin ();
876
+
877
+ if (auto inst = dyn_cast<Instruction>(
878
+ phi->getIncomingValueForBlock (PB))) {
879
+ if (inst->mayReadFromMemory () || !EnzymeSpeculatePHIs)
880
+ vals.push_back (
881
+ getOpFull (B, phi->getIncomingValueForBlock (PB), PB));
882
+ else
883
+ vals.push_back (getOpFull (
884
+ BuilderM, phi->getIncomingValueForBlock (PB), PB));
885
+ } else
886
+ vals.push_back (
887
+ getOpFull (BuilderM, phi->getIncomingValueForBlock (PB), PB));
872
888
873
- vals.push_back (getOpFull (
874
- B,
875
- phi->getIncomingValueForBlock (
876
- *done[std::make_pair (parent, predBlocks[i])].begin ()),
877
- *done[std::make_pair (parent, predBlocks[i])].begin ()));
878
889
if (!vals[i]) {
879
890
for (size_t j = 0 ; j < i; i++) {
880
891
reverseBlocks[fwd].erase (std::find (reverseBlocks[fwd].begin (),
@@ -1034,7 +1045,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1034
1045
1035
1046
if (auto inst =
1036
1047
dyn_cast<Instruction>(phi->getIncomingValueForBlock (PB))) {
1037
- if (inst->mayReadFromMemory ())
1048
+ if (inst->mayReadFromMemory () || !EnzymeSpeculatePHIs )
1038
1049
vals.push_back (getOpFull (B, phi->getIncomingValueForBlock (PB), PB));
1039
1050
else
1040
1051
vals.push_back (
0 commit comments