This repository was archived by the owner on Mar 28, 2020. It is now read-only.

Commit ba89ffc

[PM/LoopUnswitch] Fix PR37651 by correctly invalidating SCEV when
unswitching loops.

Original patch trying to address this was sent in D47624, but that didn't
quite handle things correctly. There are two key principles used to select
whether and how to invalidate SCEV-cached information about loops:

1) We must invalidate any info SCEV has cached before unswitching as we may
   change (or destroy) the loop structure by the act of unswitching, and make
   it hard to recover everything we want to invalidate within SCEV.

2) We need to invalidate all of the loops whose CFGs are mutated by the
   unswitching. Notably, this isn't the *entire* loop nest, this is every
   loop contained by the outermost loop reached by an exit block relevant to
   the unswitch.

And we need to do this even when doing trivial unswitching.

I've added more focused tests that directly check that SCEV starts off with
imprecise information and after unswitching (and simplifying instructions)
re-querying SCEV will produce precise information. These tests also
specifically work to check that an *outer* loop's information becomes
precise.

However, the testing here is still a bit imperfect. Crafting test cases that
reliably fail to be analyzed by SCEV before unswitching and succeed afterward
proved ... very, very hard. It took me several hours and careful work to
build these, and I'm not optimistic about necessarily coming up with more to
cover more elaborate possibilities. Fortunately, the code pattern we are
testing here in the pass is really straightforward and reliable.

Thanks to Max Kazantsev for the initial work on this as well as the review,
and to Hal Finkel for helping me talk through approaches to test this stuff
even if it didn't come to much.

Differential Revision: https://reviews.llvm.org/D47624

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@336183 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 099ee45 commit ba89ffc
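
As a rough illustration of principle 2 in the commit message, the sketch below models "the outermost loop reached by an exit block" with a toy loop type. `ToyLoop` and `pickInvalidationRoot` are hypothetical names invented for this example, not LLVM API. The update rule mirrors the `OuterL` computation added to `unswitchTrivialSwitch` in this commit. A null result stands for "an exit leaves the whole nest", which is the case where the patch calls `SE->forgetTopmostLoop(&L)` instead of `SE->forgetLoop(...)`.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for a loop in a loop nest. This is NOT llvm::Loop; it only
// records the enclosing loop so that containment can be queried.
struct ToyLoop {
  const ToyLoop *Parent = nullptr; // enclosing loop, or nullptr if top-level

  // True if `Other` is this loop or is nested (at any depth) inside it.
  // A null `Other` is never contained.
  bool contains(const ToyLoop *Other) const {
    for (; Other; Other = Other->Parent)
      if (Other == this)
        return true;
    return false;
  }
};

// Hypothetical helper: given the loop being unswitched (`L`) and, for each
// relevant exit block, the loop that owns it (nullptr when the exit block is
// outside every loop), return the outermost loop whose cached information
// should be dropped. A nullptr result means "forget the entire nest".
const ToyLoop *
pickInvalidationRoot(const ToyLoop &L,
                     const std::vector<const ToyLoop *> &ExitLoops) {
  const ToyLoop *Outer = &L;
  for (const ToyLoop *ExitL : ExitLoops)
    // Same update rule as the patch: widen to any exit loop that encloses
    // the current candidate, and collapse to null once an exit leaves the
    // nest. Once null, `contains(nullptr)` is false, so it stays null.
    if (!ExitL || ExitL->contains(Outer))
      Outer = ExitL;
  return Outer;
}
```

For a nest `Top > Mid > Inner` where `Inner` is being unswitched, exits landing in `Mid` and `Top` give `Top` as the root. Any exit that lands outside all loops gives null. In the real pass this decision is made before the CFG is mutated (principle 1), because the loop structure may not be recoverable afterward.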

2 files changed: 272 additions & 21 deletions

lib/Transforms/Scalar/SimpleLoopUnswitch.cpp

Lines changed: 84 additions & 21 deletions
@@ -253,8 +253,11 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB,
 /// (splitting the exit block as necessary). It simplifies the branch within
 /// the loop to an unconditional branch but doesn't remove it entirely. Further
 /// cleanup can be done with some simplify-cfg like pass.
+///
+/// If `SE` is not null, it will be updated based on the potential loop SCEVs
+/// invalidated by this.
 static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
-                                  LoopInfo &LI) {
+                                  LoopInfo &LI, ScalarEvolution *SE) {
   assert(BI.isConditional() && "Can only unswitch a conditional branch!");
   LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n");

@@ -318,6 +321,16 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
     }
   });
 
+  // If we have scalar evolutions, we need to invalidate them including this
+  // loop and the loop containing the exit block.
+  if (SE) {
+    if (Loop *ExitL = LI.getLoopFor(LoopExitBB))
+      SE->forgetLoop(ExitL);
+    else
+      // Forget the entire nest as this exits the entire nest.
+      SE->forgetTopmostLoop(&L);
+  }
+
   // Split the preheader, so that we know that there is a safe place to insert
   // the conditional branch. We will change the preheader to have a conditional
   // branch on LoopCond.
@@ -420,8 +433,11 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
 /// switch will not be revisited. If after unswitching there is only a single
 /// in-loop successor, the switch is further simplified to an unconditional
 /// branch. Still more cleanup can be done with some simplify-cfg like pass.
+///
+/// If `SE` is not null, it will be updated based on the potential loop SCEVs
+/// invalidated by this.
 static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
-                                  LoopInfo &LI) {
+                                  LoopInfo &LI, ScalarEvolution *SE) {
   LLVM_DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n");
   Value *LoopCond = SI.getCondition();

@@ -448,18 +464,33 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
 
   LLVM_DEBUG(dbgs() << " unswitching trivial cases...\n");
 
+  // We may need to invalidate SCEVs for the outermost loop reached by any of
+  // the exits.
+  Loop *OuterL = &L;
+
   SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases;
   ExitCases.reserve(ExitCaseIndices.size());
   // We walk the case indices backwards so that we remove the last case first
   // and don't disrupt the earlier indices.
   for (unsigned Index : reverse(ExitCaseIndices)) {
     auto CaseI = SI.case_begin() + Index;
+    // Compute the outer loop from this exit.
+    Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor());
+    if (!ExitL || ExitL->contains(OuterL))
+      OuterL = ExitL;
     // Save the value of this case.
     ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()});
     // Delete the unswitched cases.
     SI.removeCase(CaseI);
   }
 
+  if (SE) {
+    if (OuterL)
+      SE->forgetLoop(OuterL);
+    else
+      SE->forgetTopmostLoop(&L);
+  }
+
   // Check if after this all of the remaining cases point at the same
   // successor.
   BasicBlock *CommonSuccBB = nullptr;
@@ -617,8 +648,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
 ///
 /// The return value indicates whether anything was unswitched (and therefore
 /// changed).
+///
+/// If `SE` is not null, it will be updated based on the potential loop SCEVs
+/// invalidated by this.
 static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT,
-                                         LoopInfo &LI) {
+                                         LoopInfo &LI, ScalarEvolution *SE) {
   bool Changed = false;
 
   // If loop header has only one reachable successor we should keep looking for
@@ -652,7 +686,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT,
       if (isa<Constant>(SI->getCondition()))
         return Changed;
 
-      if (!unswitchTrivialSwitch(L, *SI, DT, LI))
+      if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE))
         // Couldn't unswitch this one so we're done.
         return Changed;

@@ -684,7 +718,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT,
 
     // Found a trivial condition candidate: non-foldable conditional branch. If
     // we fail to unswitch this, we can't do anything else that is trivial.
-    if (!unswitchTrivialBranch(L, *BI, DT, LI))
+    if (!unswitchTrivialBranch(L, *BI, DT, LI, SE))
       return Changed;
 
     // Mark that we managed to unswitch something.
@@ -1622,7 +1656,8 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
 static bool unswitchNontrivialInvariants(
     Loop &L, TerminatorInst &TI, ArrayRef<Value *> Invariants,
     DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-    function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+    function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+    ScalarEvolution *SE) {
   auto *ParentBB = TI.getParent();
   BranchInst *BI = dyn_cast<BranchInst>(&TI);
   SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
@@ -1705,6 +1740,16 @@ static bool unswitchNontrivialInvariants(
       OuterExitL = NewOuterExitL;
   }
 
+  // At this point, we're definitely going to unswitch something so invalidate
+  // any cached information in ScalarEvolution for the outer most loop
+  // containing an exit block and all nested loops.
+  if (SE) {
+    if (OuterExitL)
+      SE->forgetLoop(OuterExitL);
+    else
+      SE->forgetTopmostLoop(&L);
+  }
+
   // If the edge from this terminator to a successor dominates that successor,
   // store a map from each block in its dominator subtree to it. This lets us
   // tell when cloning for a particular successor if a block is dominated by
@@ -1968,10 +2013,11 @@ computeDomSubtreeCost(DomTreeNode &N,
   return Cost;
 }
 
-static bool unswitchBestCondition(
-    Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-    TargetTransformInfo &TTI,
-    function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+static bool
+unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                      AssumptionCache &AC, TargetTransformInfo &TTI,
+                      function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+                      ScalarEvolution *SE) {
   // Collect all invariant conditions within this loop (as opposed to an inner
   // loop which would be handled when visiting that inner loop).
   SmallVector<std::pair<TerminatorInst *, TinyPtrVector<Value *>>, 4>
@@ -2164,7 +2210,7 @@ static bool unswitchBestCondition(
                     << BestUnswitchCost << ") terminator: " << *BestUnswitchTI
                     << "\n");
   return unswitchNontrivialInvariants(
-      L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB);
+      L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB, SE);
 }
 
 /// Unswitch control flow predicated on loop invariant conditions.
@@ -2173,10 +2219,25 @@ static bool unswitchBestCondition(
 /// require duplicating any part of the loop) out of the loop body. It then
 /// looks at other loop invariant control flows and tries to unswitch those as
 /// well by cloning the loop if the result is small enough.
-static bool
-unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-             TargetTransformInfo &TTI, bool NonTrivial,
-             function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+///
+/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also
+/// updated based on the unswitch.
+///
+/// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is
+/// true, we will attempt to do non-trivial unswitching as well as trivial
+/// unswitching.
+///
+/// The `UnswitchCB` callback provided will be run after unswitching is
+/// complete, with the first parameter set to `true` if the provided loop
+/// remains a loop, and a list of new sibling loops created.
+///
+/// If `SE` is non-null, we will update that analysis based on the unswitching
+/// done.
+static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                         AssumptionCache &AC, TargetTransformInfo &TTI,
+                         bool NonTrivial,
+                         function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+                         ScalarEvolution *SE) {
   assert(L.isRecursivelyLCSSAForm(DT, LI) &&
          "Loops must be in LCSSA form before unswitching.");
   bool Changed = false;
@@ -2186,7 +2247,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
     return false;
 
   // Try trivial unswitch first before loop over other basic blocks in the loop.
-  if (unswitchAllTrivialConditions(L, DT, LI)) {
+  if (unswitchAllTrivialConditions(L, DT, LI, SE)) {
     // If we unswitched successfully we will want to clean up the loop before
     // processing it further so just mark it as unswitched and return.
     UnswitchCB(/*CurrentLoopValid*/ true, {});
@@ -2207,7 +2268,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
 
   // Try to unswitch the best invariant condition. We prefer this full unswitch to
   // a partial unswitch when possible below the threshold.
-  if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB))
+  if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE))
     return true;
 
   // No other opportunities to unswitch.
@@ -2241,8 +2302,8 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
       U.markLoopAsDeleted(L, LoopName);
   };
 
-  if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial,
-                    UnswitchCB))
+  if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB,
+                    &AR.SE))
     return PreservedAnalyses::all();
 
   // Historically this pass has had issues with the dominator tree so verify it
@@ -2290,6 +2351,9 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
+  auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
+  auto *SE = SEWP ? &SEWP->getSE() : nullptr;
+
   auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid,
                                ArrayRef<Loop *> NewLoops) {
     // If we did a non-trivial unswitch, we have added new (cloned) loops.
@@ -2305,8 +2369,7 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
       LPM.markLoopAsDeleted(*L);
   };
 
-  bool Changed =
-      unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB);
+  bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE);
 
   // If anything was unswitched, also clear any cached information about this
   // loop.
