@@ -253,8 +253,11 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB,
253
253
// / (splitting the exit block as necessary). It simplifies the branch within
254
254
// / the loop to an unconditional branch but doesn't remove it entirely. Further
255
255
// / cleanup can be done with some simplify-cfg like pass.
256
+ // /
257
+ // / If `SE` is not null, it will be updated based on the potential loop SCEVs
258
+ // / invalidated by this.
256
259
static bool unswitchTrivialBranch (Loop &L, BranchInst &BI, DominatorTree &DT,
257
- LoopInfo &LI) {
260
+ LoopInfo &LI, ScalarEvolution *SE ) {
258
261
assert (BI.isConditional () && " Can only unswitch a conditional branch!" );
259
262
LLVM_DEBUG (dbgs () << " Trying to unswitch branch: " << BI << " \n " );
260
263
@@ -318,6 +321,16 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
318
321
}
319
322
});
320
323
324
+ // If we have scalar evolutions, we need to invalidate them including this
325
+ // loop and the loop containing the exit block.
326
+ if (SE) {
327
+ if (Loop *ExitL = LI.getLoopFor (LoopExitBB))
328
+ SE->forgetLoop (ExitL);
329
+ else
330
+ // Forget the entire nest as this exits the entire nest.
331
+ SE->forgetTopmostLoop (&L);
332
+ }
333
+
321
334
// Split the preheader, so that we know that there is a safe place to insert
322
335
// the conditional branch. We will change the preheader to have a conditional
323
336
// branch on LoopCond.
@@ -420,8 +433,11 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
420
433
// / switch will not be revisited. If after unswitching there is only a single
421
434
// / in-loop successor, the switch is further simplified to an unconditional
422
435
// / branch. Still more cleanup can be done with some simplify-cfg like pass.
436
+ // /
437
+ // / If `SE` is not null, it will be updated based on the potential loop SCEVs
438
+ // / invalidated by this.
423
439
static bool unswitchTrivialSwitch (Loop &L, SwitchInst &SI, DominatorTree &DT,
424
- LoopInfo &LI) {
440
+ LoopInfo &LI, ScalarEvolution *SE ) {
425
441
LLVM_DEBUG (dbgs () << " Trying to unswitch switch: " << SI << " \n " );
426
442
Value *LoopCond = SI.getCondition ();
427
443
@@ -448,18 +464,33 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
448
464
449
465
LLVM_DEBUG (dbgs () << " unswitching trivial cases...\n " );
450
466
467
+ // We may need to invalidate SCEVs for the outermost loop reached by any of
468
+ // the exits.
469
+ Loop *OuterL = &L;
470
+
451
471
SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4 > ExitCases;
452
472
ExitCases.reserve (ExitCaseIndices.size ());
453
473
// We walk the case indices backwards so that we remove the last case first
454
474
// and don't disrupt the earlier indices.
455
475
for (unsigned Index : reverse (ExitCaseIndices)) {
456
476
auto CaseI = SI.case_begin () + Index;
477
+ // Compute the outer loop from this exit.
478
+ Loop *ExitL = LI.getLoopFor (CaseI->getCaseSuccessor ());
479
+ if (!ExitL || ExitL->contains (OuterL))
480
+ OuterL = ExitL;
457
481
// Save the value of this case.
458
482
ExitCases.push_back ({CaseI->getCaseValue (), CaseI->getCaseSuccessor ()});
459
483
// Delete the unswitched cases.
460
484
SI.removeCase (CaseI);
461
485
}
462
486
487
+ if (SE) {
488
+ if (OuterL)
489
+ SE->forgetLoop (OuterL);
490
+ else
491
+ SE->forgetTopmostLoop (&L);
492
+ }
493
+
463
494
// Check if after this all of the remaining cases point at the same
464
495
// successor.
465
496
BasicBlock *CommonSuccBB = nullptr ;
@@ -617,8 +648,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
617
648
// /
618
649
// / The return value indicates whether anything was unswitched (and therefore
619
650
// / changed).
651
+ // /
652
+ // / If `SE` is not null, it will be updated based on the potential loop SCEVs
653
+ // / invalidated by this.
620
654
static bool unswitchAllTrivialConditions (Loop &L, DominatorTree &DT,
621
- LoopInfo &LI) {
655
+ LoopInfo &LI, ScalarEvolution *SE ) {
622
656
bool Changed = false ;
623
657
624
658
// If loop header has only one reachable successor we should keep looking for
@@ -652,7 +686,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT,
652
686
if (isa<Constant>(SI->getCondition ()))
653
687
return Changed;
654
688
655
- if (!unswitchTrivialSwitch (L, *SI, DT, LI))
689
+ if (!unswitchTrivialSwitch (L, *SI, DT, LI, SE ))
656
690
// Couldn't unswitch this one so we're done.
657
691
return Changed;
658
692
@@ -684,7 +718,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT,
684
718
685
719
// Found a trivial condition candidate: non-foldable conditional branch. If
686
720
// we fail to unswitch this, we can't do anything else that is trivial.
687
- if (!unswitchTrivialBranch (L, *BI, DT, LI))
721
+ if (!unswitchTrivialBranch (L, *BI, DT, LI, SE ))
688
722
return Changed;
689
723
690
724
// Mark that we managed to unswitch something.
@@ -1622,7 +1656,8 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
1622
1656
static bool unswitchNontrivialInvariants (
1623
1657
Loop &L, TerminatorInst &TI, ArrayRef<Value *> Invariants,
1624
1658
DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
1625
- function_ref<void (bool , ArrayRef<Loop *>)> UnswitchCB) {
1659
+ function_ref<void (bool , ArrayRef<Loop *>)> UnswitchCB,
1660
+ ScalarEvolution *SE) {
1626
1661
auto *ParentBB = TI.getParent ();
1627
1662
BranchInst *BI = dyn_cast<BranchInst>(&TI);
1628
1663
SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
@@ -1705,6 +1740,16 @@ static bool unswitchNontrivialInvariants(
1705
1740
OuterExitL = NewOuterExitL;
1706
1741
}
1707
1742
1743
+ // At this point, we're definitely going to unswitch something so invalidate
1744
+ // any cached information in ScalarEvolution for the outer most loop
1745
+ // containing an exit block and all nested loops.
1746
+ if (SE) {
1747
+ if (OuterExitL)
1748
+ SE->forgetLoop (OuterExitL);
1749
+ else
1750
+ SE->forgetTopmostLoop (&L);
1751
+ }
1752
+
1708
1753
// If the edge from this terminator to a successor dominates that successor,
1709
1754
// store a map from each block in its dominator subtree to it. This lets us
1710
1755
// tell when cloning for a particular successor if a block is dominated by
@@ -1968,10 +2013,11 @@ computeDomSubtreeCost(DomTreeNode &N,
1968
2013
return Cost;
1969
2014
}
1970
2015
1971
- static bool unswitchBestCondition (
1972
- Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
1973
- TargetTransformInfo &TTI,
1974
- function_ref<void (bool , ArrayRef<Loop *>)> UnswitchCB) {
2016
+ static bool
2017
+ unswitchBestCondition (Loop &L, DominatorTree &DT, LoopInfo &LI,
2018
+ AssumptionCache &AC, TargetTransformInfo &TTI,
2019
+ function_ref<void (bool , ArrayRef<Loop *>)> UnswitchCB,
2020
+ ScalarEvolution *SE) {
1975
2021
// Collect all invariant conditions within this loop (as opposed to an inner
1976
2022
// loop which would be handled when visiting that inner loop).
1977
2023
SmallVector<std::pair<TerminatorInst *, TinyPtrVector<Value *>>, 4 >
@@ -2164,7 +2210,7 @@ static bool unswitchBestCondition(
2164
2210
<< BestUnswitchCost << " ) terminator: " << *BestUnswitchTI
2165
2211
<< " \n " );
2166
2212
return unswitchNontrivialInvariants (
2167
- L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB);
2213
+ L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB, SE );
2168
2214
}
2169
2215
2170
2216
// / Unswitch control flow predicated on loop invariant conditions.
@@ -2173,10 +2219,25 @@ static bool unswitchBestCondition(
2173
2219
// / require duplicating any part of the loop) out of the loop body. It then
2174
2220
// / looks at other loop invariant control flows and tries to unswitch those as
2175
2221
// / well by cloning the loop if the result is small enough.
2176
- static bool
2177
- unswitchLoop (Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
2178
- TargetTransformInfo &TTI, bool NonTrivial,
2179
- function_ref<void (bool , ArrayRef<Loop *>)> UnswitchCB) {
2222
+ // /
2223
+ // / The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also
2224
+ // / updated based on the unswitch.
2225
+ // /
2226
+ // / If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is
2227
+ // / true, we will attempt to do non-trivial unswitching as well as trivial
2228
+ // / unswitching.
2229
+ // /
2230
+ // / The `UnswitchCB` callback provided will be run after unswitching is
2231
+ // / complete, with the first parameter set to `true` if the provided loop
2232
+ // / remains a loop, and a list of new sibling loops created.
2233
+ // /
2234
+ // / If `SE` is non-null, we will update that analysis based on the unswitching
2235
+ // / done.
2236
+ static bool unswitchLoop (Loop &L, DominatorTree &DT, LoopInfo &LI,
2237
+ AssumptionCache &AC, TargetTransformInfo &TTI,
2238
+ bool NonTrivial,
2239
+ function_ref<void (bool , ArrayRef<Loop *>)> UnswitchCB,
2240
+ ScalarEvolution *SE) {
2180
2241
assert (L.isRecursivelyLCSSAForm (DT, LI) &&
2181
2242
" Loops must be in LCSSA form before unswitching." );
2182
2243
bool Changed = false ;
@@ -2186,7 +2247,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
2186
2247
return false ;
2187
2248
2188
2249
// Try trivial unswitch first before loop over other basic blocks in the loop.
2189
- if (unswitchAllTrivialConditions (L, DT, LI)) {
2250
+ if (unswitchAllTrivialConditions (L, DT, LI, SE )) {
2190
2251
// If we unswitched successfully we will want to clean up the loop before
2191
2252
// processing it further so just mark it as unswitched and return.
2192
2253
UnswitchCB (/* CurrentLoopValid*/ true , {});
@@ -2207,7 +2268,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
2207
2268
2208
2269
// Try to unswitch the best invariant condition. We prefer this full unswitch to
2209
2270
// a partial unswitch when possible below the threshold.
2210
- if (unswitchBestCondition (L, DT, LI, AC, TTI, UnswitchCB))
2271
+ if (unswitchBestCondition (L, DT, LI, AC, TTI, UnswitchCB, SE ))
2211
2272
return true ;
2212
2273
2213
2274
// No other opportunities to unswitch.
@@ -2241,8 +2302,8 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
2241
2302
U.markLoopAsDeleted (L, LoopName);
2242
2303
};
2243
2304
2244
- if (!unswitchLoop (L, AR.DT , AR.LI , AR.AC , AR.TTI , NonTrivial,
2245
- UnswitchCB ))
2305
+ if (!unswitchLoop (L, AR.DT , AR.LI , AR.AC , AR.TTI , NonTrivial, UnswitchCB,
2306
+ &AR. SE ))
2246
2307
return PreservedAnalyses::all ();
2247
2308
2248
2309
// Historically this pass has had issues with the dominator tree so verify it
@@ -2290,6 +2351,9 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
2290
2351
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache (F);
2291
2352
auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI (F);
2292
2353
2354
+ auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
2355
+ auto *SE = SEWP ? &SEWP->getSE () : nullptr ;
2356
+
2293
2357
auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid,
2294
2358
ArrayRef<Loop *> NewLoops) {
2295
2359
// If we did a non-trivial unswitch, we have added new (cloned) loops.
@@ -2305,8 +2369,7 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
2305
2369
LPM.markLoopAsDeleted (*L);
2306
2370
};
2307
2371
2308
- bool Changed =
2309
- unswitchLoop (*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB);
2372
+ bool Changed = unswitchLoop (*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE);
2310
2373
2311
2374
// If anything was unswitched, also clear any cached information about this
2312
2375
// loop.
0 commit comments