Skip to content

Commit ad5a84c

Browse files
committed
[LoopPred/WC] Use a dominating widenable condition to remove analyze loop exits
This implements a version of the predicateLoopExits transform from IndVarSimplify extended to exploit widenable conditions - and thus be much wider in scope of legality. The code structure ends up being almost entirely different, so I chose to duplicate this into the LoopPredication pass instead of trying to reuse the code in the IndVars. The core notions of the transform are as follows: If we have a widenable condition which controls entry into the loop, we're allowed to widen it arbitrarily. Given that, it's simply a *profitability* question as to what conditions to fold into the widenable branch. To avoid pass ordering issues, we want to avoid widening cases that would otherwise be dischargeable. Or... widen in a form which can still be discharged. Thus, we phrase the transform as selecting one analyzeable exit from the set of analyzeable exits to keep. This avoids creating pass ordering complexities. Since none of the above proves that we actually exit through our analyzeable exits - we might exit through something else entirely - we limit ourselves to cases where a) the latch is analyzeable and b) the latch is predicted taken, and c) the exit being removed is statically cold. Differential Revision: https://reviews.llvm.org/D69830
1 parent e15b26f commit ad5a84c

File tree

2 files changed

+964
-9
lines changed

2 files changed

+964
-9
lines changed

llvm/lib/Transforms/Scalar/LoopPredication.cpp

Lines changed: 206 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ struct LoopICmp {
250250

251251
class LoopPredication {
252252
AliasAnalysis *AA;
253+
DominatorTree *DT;
253254
ScalarEvolution *SE;
255+
LoopInfo *LI;
254256
BranchProbabilityInfo *BPI;
255257

256258
Loop *L;
@@ -302,10 +304,13 @@ class LoopPredication {
302304
// within the loop. We identify such unprofitable loops through BPI.
303305
bool isLoopProfitableToPredicate();
304306

307+
bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
308+
305309
public:
306-
LoopPredication(AliasAnalysis *AA, ScalarEvolution *SE,
310+
LoopPredication(AliasAnalysis *AA, DominatorTree *DT,
311+
ScalarEvolution *SE, LoopInfo *LI,
307312
BranchProbabilityInfo *BPI)
308-
: AA(AA), SE(SE), BPI(BPI){};
313+
: AA(AA), DT(DT), SE(SE), LI(LI), BPI(BPI) {};
309314
bool runOnLoop(Loop *L);
310315
};
311316

@@ -325,10 +330,12 @@ class LoopPredicationLegacyPass : public LoopPass {
325330
if (skipLoop(L))
326331
return false;
327332
auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
333+
auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
334+
auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
328335
BranchProbabilityInfo &BPI =
329336
getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
330337
auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
331-
LoopPredication LP(AA, SE, &BPI);
338+
LoopPredication LP(AA, DT, SE, LI, &BPI);
332339
return LP.runOnLoop(L);
333340
}
334341
};
@@ -354,7 +361,7 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
354361
AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
355362
Function *F = L.getHeader()->getParent();
356363
auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
357-
LoopPredication LP(&AR.AA, &AR.SE, BPI);
364+
LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, BPI);
358365
if (!LP.runOnLoop(&L))
359366
return PreservedAnalyses::all();
360367

@@ -955,6 +962,200 @@ bool LoopPredication::isLoopProfitableToPredicate() {
955962
return true;
956963
}
957964

965+
/// If we can (cheaply) find a widenable branch which controls entry into the
966+
/// loop, return it.
967+
static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
968+
// Walk back through any unconditional executed blocks and see if we can find
969+
// a widenable condition which seems to control execution of this loop. Note
970+
// that we predict that maythrow calls are likely untaken and thus that it's
971+
// profitable to widen a branch before a maythrow call with a condition
972+
// afterwards even though that may cause the slow path to run in a case where
973+
// it wouldn't have otherwise.
974+
BasicBlock *BB = L->getLoopPreheader();
975+
if (!BB)
976+
return nullptr;
977+
do {
978+
if (BasicBlock *Pred = BB->getSinglePredecessor())
979+
if (BB == Pred->getSingleSuccessor()) {
980+
BB = Pred;
981+
continue;
982+
}
983+
break;
984+
} while (true);
985+
986+
if (BasicBlock *Pred = BB->getSinglePredecessor()) {
987+
auto *Term = Pred->getTerminator();
988+
989+
Value *Cond, *WC;
990+
BasicBlock *IfTrueBB, *IfFalseBB;
991+
if (parseWidenableBranch(Term, Cond, WC, IfTrueBB, IfFalseBB) &&
992+
IfTrueBB == BB)
993+
return cast<BranchInst>(Term);
994+
}
995+
return nullptr;
996+
}
997+
998+
/// Return the minimum of all analyzeable exit counts. This is an upper bound
999+
/// on the actual exit count. If there are not at least two analyzeable exits,
1000+
/// returns SCEVCouldNotCompute.
1001+
static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
1002+
DominatorTree &DT,
1003+
Loop *L) {
1004+
SmallVector<BasicBlock *, 16> ExitingBlocks;
1005+
L->getExitingBlocks(ExitingBlocks);
1006+
1007+
SmallVector<const SCEV *, 4> ExitCounts;
1008+
for (BasicBlock *ExitingBB : ExitingBlocks) {
1009+
const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
1010+
if (isa<SCEVCouldNotCompute>(ExitCount))
1011+
continue;
1012+
assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
1013+
"We should only have known counts for exiting blocks that "
1014+
"dominate latch!");
1015+
ExitCounts.push_back(ExitCount);
1016+
}
1017+
if (ExitCounts.size() < 2)
1018+
return SE.getCouldNotCompute();
1019+
return SE.getUMinFromMismatchedTypes(ExitCounts);
1020+
}
1021+
1022+
/// This implements an analogous, but entirely distinct transform from the main
1023+
/// loop predication transform. This one is phrased in terms of using a
1024+
/// widenable branch *outside* the loop to allow us to simplify loop exits in a
1025+
/// following loop. This is close in spirit to the IndVarSimplify transform
1026+
/// of the same name, but is materially different widening loosens legality
1027+
/// sharply.
1028+
bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
1029+
// The transformation performed here aims to widen a widenable condition
1030+
// above the loop such that all analyzeable exit leading to deopt are dead.
1031+
// It assumes that the latch is the dominant exit for profitability and that
1032+
// exits branching to deoptimizing blocks are rarely taken. It relies on the
1033+
// semantics of widenable expressions for legality. (i.e. being able to fall
1034+
// down the widenable path spuriously allows us to ignore exit order,
1035+
// unanalyzeable exits, side effects, exceptional exits, and other challenges
1036+
// which restrict the applicability of the non-WC based version of this
1037+
// transform in IndVarSimplify.)
1038+
//
1039+
// NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may
1040+
// imply flags on the expression being hoisted and inserting new uses (flags
1041+
// are only correct for current uses). The result is that we may be
1042+
// inserting a branch on the value which can be either poison or undef. In
1043+
// this case, the branch can legally go either way; we just need to avoid
1044+
// introducing UB. This is achieved through the use of the freeze
1045+
// instruction.
1046+
1047+
SmallVector<BasicBlock *, 16> ExitingBlocks;
1048+
L->getExitingBlocks(ExitingBlocks);
1049+
1050+
if (ExitingBlocks.empty())
1051+
return false; // Nothing to do.
1052+
1053+
auto *Latch = L->getLoopLatch();
1054+
if (!Latch)
1055+
return false;
1056+
1057+
auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, *LI);
1058+
if (!WidenableBR)
1059+
return false;
1060+
1061+
const SCEV *LatchEC = SE->getExitCount(L, Latch);
1062+
if (isa<SCEVCouldNotCompute>(LatchEC))
1063+
return false; // profitability - want hot exit in analyzeable set
1064+
1065+
// The use of umin(all analyzeable exits) instead of latch is subtle, but
1066+
// important for profitability. We may have a loop which hasn't been fully
1067+
// canonicalized just yet. If the exit we chose to widen is provably never
1068+
// taken, we want the widened form to *also* be provably never taken. We
1069+
// can't guarantee this as a current unanalyzeable exit may later become
1070+
// analyzeable, but we can at least avoid the obvious cases.
1071+
const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
1072+
if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
1073+
!SE->isLoopInvariant(MinEC, L) ||
1074+
!isSafeToExpandAt(MinEC, WidenableBR, *SE))
1075+
return false;
1076+
1077+
// Subtlety: We need to avoid inserting additional uses of the WC. We know
1078+
// that it can only have one transitive use at the moment, and thus moving
1079+
// that use to just before the branch and inserting code before it and then
1080+
// modifying the operand is legal.
1081+
auto *IP = cast<Instruction>(WidenableBR->getCondition());
1082+
IP->moveBefore(WidenableBR);
1083+
Rewriter.setInsertPoint(IP);
1084+
IRBuilder<> B(IP);
1085+
1086+
bool Changed = false;
1087+
Value *MinECV = nullptr; // lazily generated if needed
1088+
for (BasicBlock *ExitingBB : ExitingBlocks) {
1089+
// If our exiting block exits multiple loops, we can only rewrite the
1090+
// innermost one. Otherwise, we're changing how many times the innermost
1091+
// loop runs before it exits.
1092+
if (LI->getLoopFor(ExitingBB) != L)
1093+
continue;
1094+
1095+
// Can't rewrite non-branch yet.
1096+
auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1097+
if (!BI)
1098+
continue;
1099+
1100+
// If already constant, nothing to do.
1101+
if (isa<Constant>(BI->getCondition()))
1102+
continue;
1103+
1104+
const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
1105+
if (isa<SCEVCouldNotCompute>(ExitCount) ||
1106+
ExitCount->getType()->isPointerTy() ||
1107+
!isSafeToExpandAt(ExitCount, WidenableBR, *SE))
1108+
continue;
1109+
1110+
const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
1111+
BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
1112+
if (!ExitBB->getTerminatingDeoptimizeCall())
1113+
// Profitability: indicator of rarely/never taken exit
1114+
continue;
1115+
1116+
// If we found a widenable exit condition, do two things:
1117+
// 1) fold the widened exit test into the widenable condition
1118+
// 2) fold the branch to untaken - avoids infinite looping
1119+
1120+
Value *ECV = Rewriter.expandCodeFor(ExitCount);
1121+
if (!MinECV)
1122+
MinECV = Rewriter.expandCodeFor(MinEC);
1123+
Value *RHS = MinECV;
1124+
if (ECV->getType() != RHS->getType()) {
1125+
Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
1126+
ECV = B.CreateZExt(ECV, WiderTy);
1127+
RHS = B.CreateZExt(RHS, WiderTy);
1128+
}
1129+
assert(!Latch || DT->dominates(ExitingBB, Latch));
1130+
Value *NewCond = B.CreateICmp(ICmpInst::ICMP_UGT, ECV, RHS);
1131+
// Freeze poison or undef to an arbitrary bit pattern to ensure we can
1132+
// branch without introducing UB. See NOTE ON POISON/UNDEF above for
1133+
// context.
1134+
NewCond = B.CreateFreeze(NewCond);
1135+
1136+
Value *Cond, *WC;
1137+
BasicBlock *IfTrueBB, *IfFalseBB;
1138+
bool Success =
1139+
parseWidenableBranch(WidenableBR, Cond, WC, IfTrueBB, IfFalseBB);
1140+
assert(Success && "implied from above");
1141+
(void)Success;
1142+
Instruction *WCAnd = cast<Instruction>(WidenableBR->getCondition());
1143+
WCAnd->setOperand(0, B.CreateAnd(NewCond, Cond));
1144+
1145+
Value *OldCond = BI->getCondition();
1146+
BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
1147+
Changed = true;
1148+
}
1149+
1150+
if (Changed)
1151+
// We just mutated a bunch of loop exits changing there exit counts
1152+
// widely. We need to force recomputation of the exit counts given these
1153+
// changes. Note that all of the inserted exits are never taken, and
1154+
// should be removed next time the CFG is modified.
1155+
SE->forgetLoop(L);
1156+
return Changed;
1157+
}
1158+
9581159
bool LoopPredication::runOnLoop(Loop *Loop) {
9591160
L = Loop;
9601161

@@ -1006,16 +1207,12 @@ bool LoopPredication::runOnLoop(Loop *Loop) {
10061207
cast<BranchInst>(BB->getTerminator()));
10071208
}
10081209

1009-
if (Guards.empty() && GuardsAsWidenableBranches.empty())
1010-
return false;
1011-
10121210
SCEVExpander Expander(*SE, *DL, "loop-predication");
1013-
10141211
bool Changed = false;
10151212
for (auto *Guard : Guards)
10161213
Changed |= widenGuardConditions(Guard, Expander);
10171214
for (auto *Guard : GuardsAsWidenableBranches)
10181215
Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
1019-
1216+
Changed |= predicateLoopExits(L, Expander);
10201217
return Changed;
10211218
}

0 commit comments

Comments
 (0)