Skip to content

Commit 16b85a6

Browse files
committed
Fix a bug when unswitching on partial LIV for SwitchInst
Summary: Fix a bug when unswitching on partial LIV for SwitchInst. Reviewers: hfinkel, efriedma, sanjoy Reviewed By: sanjoy Subscribers: david2050, mzolotukhin, llvm-commits Differential Revision: https://reviews.llvm.org/D29107 llvm-svn: 296363
1 parent 08d0840 commit 16b85a6

File tree

2 files changed

+339
-32
lines changed

2 files changed

+339
-32
lines changed

llvm/lib/Transforms/Scalar/LoopUnswitch.cpp

Lines changed: 128 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,27 @@ Pass *llvm::createLoopUnswitchPass(bool Os) {
374374
return new LoopUnswitch(Os);
375375
}
376376

377+
/// Operator chain lattice.
378+
enum OperatorChain {
379+
OC_OpChainNone, ///< There is no operator.
380+
OC_OpChainOr, ///< There are only ORs.
381+
OC_OpChainAnd, ///< There are only ANDs.
382+
OC_OpChainMixed ///< There are ANDs and ORs.
383+
};
384+
377385
/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
378386
/// an invariant piece, return the invariant. Otherwise, return null.
387+
//
388+
/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a
389+
/// mixed operator chain, as we can not reliably find a value which will simplify
390+
/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0
391+
/// to simplify the chain.
392+
///
393+
/// NOTE: In case a partial LIV and a mixed operator chain, we may be able to
394+
/// simplify the condition itself to a loop variant condition, but at the
395+
/// cost of creating an entirely new loop.
379396
static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
397+
OperatorChain &ParentChain,
380398
DenseMap<Value *, Value *> &Cache) {
381399
auto CacheIt = Cache.find(Cond);
382400
if (CacheIt != Cache.end())
@@ -400,31 +418,75 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
400418
return Cond;
401419
}
402420

421+
// Walk up the operator chain to find partial invariant conditions.
403422
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond))
404423
if (BO->getOpcode() == Instruction::And ||
405424
BO->getOpcode() == Instruction::Or) {
406-
// If either the left or right side is invariant, we can unswitch on this,
407-
// which will cause the branch to go away in one loop and the condition to
408-
// simplify in the other one.
409-
if (Value *LHS =
410-
FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) {
411-
Cache[Cond] = LHS;
412-
return LHS;
425+
// Given the previous operator, compute the current operator chain status.
426+
OperatorChain NewChain;
427+
switch (ParentChain) {
428+
case OC_OpChainNone:
429+
NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
430+
OC_OpChainOr;
431+
break;
432+
case OC_OpChainOr:
433+
NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr :
434+
OC_OpChainMixed;
435+
break;
436+
case OC_OpChainAnd:
437+
NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
438+
OC_OpChainMixed;
439+
break;
440+
case OC_OpChainMixed:
441+
NewChain = OC_OpChainMixed;
442+
break;
413443
}
414-
if (Value *RHS =
415-
FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) {
416-
Cache[Cond] = RHS;
417-
return RHS;
444+
445+
// If we reach a Mixed state, we do not want to keep walking up as we can not
446+
// reliably find a value that will simplify the chain. With this check, we
447+
// will return null on the first sight of mixed chain and the caller will
448+
// either backtrack to find partial LIV in other operand or return null.
449+
if (NewChain != OC_OpChainMixed) {
450+
// Update the current operator chain type before we search up the chain.
451+
ParentChain = NewChain;
452+
// If either the left or right side is invariant, we can unswitch on this,
453+
// which will cause the branch to go away in one loop and the condition to
454+
// simplify in the other one.
455+
if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed,
456+
ParentChain, Cache)) {
457+
Cache[Cond] = LHS;
458+
return LHS;
459+
}
460+
// We did not manage to find a partial LIV in operand(0). Backtrack and try
461+
// operand(1).
462+
ParentChain = NewChain;
463+
if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed,
464+
ParentChain, Cache)) {
465+
Cache[Cond] = RHS;
466+
return RHS;
467+
}
418468
}
419469
}
420470

421471
Cache[Cond] = nullptr;
422472
return nullptr;
423473
}
424474

425-
static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) {
475+
/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
476+
/// an invariant piece, return the invariant along with the operator chain type.
477+
/// Otherwise, return null.
478+
static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond,
479+
Loop *L,
480+
bool &Changed) {
426481
DenseMap<Value *, Value *> Cache;
427-
return FindLIVLoopCondition(Cond, L, Changed, Cache);
482+
OperatorChain OpChain = OC_OpChainNone;
483+
Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache);
484+
485+
// In case we do find a LIV, it can not be obtained by walking up a mixed
486+
// operator chain.
487+
assert((!FCond || OpChain != OC_OpChainMixed) &&
488+
"Do not expect a partial LIV with mixed operator chain");
489+
return {FCond, OpChain};
428490
}
429491

430492
bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
@@ -556,7 +618,7 @@ bool LoopUnswitch::processCurrentLoop() {
556618

557619
for (IntrinsicInst *Guard : Guards) {
558620
Value *LoopCond =
559-
FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed);
621+
FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first;
560622
if (LoopCond &&
561623
UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
562624
// NB! Unswitching (if successful) could have erased some of the
@@ -597,32 +659,57 @@ bool LoopUnswitch::processCurrentLoop() {
597659
// See if this, or some part of it, is loop invariant. If so, we can
598660
// unswitch on it if we desire.
599661
Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
600-
currentLoop, Changed);
662+
currentLoop, Changed).first;
601663
if (LoopCond &&
602664
UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
603665
++NumBranches;
604666
return true;
605667
}
606668
}
607669
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
608-
Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
609-
currentLoop, Changed);
670+
Value *SC = SI->getCondition();
671+
Value *LoopCond;
672+
OperatorChain OpChain;
673+
std::tie(LoopCond, OpChain) =
674+
FindLIVLoopCondition(SC, currentLoop, Changed);
675+
610676
unsigned NumCases = SI->getNumCases();
611677
if (LoopCond && NumCases) {
612678
// Find a value to unswitch on:
613679
// FIXME: this should chose the most expensive case!
614680
// FIXME: scan for a case with a non-critical edge?
615681
Constant *UnswitchVal = nullptr;
616-
617-
// Do not process same value again and again.
618-
// At this point we have some cases already unswitched and
619-
// some not yet unswitched. Let's find the first not yet unswitched one.
620-
for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
621-
i != e; ++i) {
622-
Constant *UnswitchValCandidate = i.getCaseValue();
623-
if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
624-
UnswitchVal = UnswitchValCandidate;
625-
break;
682+
// Find a case value such that at least one case value is unswitched
683+
// out.
684+
if (OpChain == OC_OpChainAnd) {
685+
// If the chain only has ANDs and the switch has a case value of 0.
686+
// Dropping in a 0 to the chain will unswitch out the 0-casevalue.
687+
auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType()));
688+
if (BranchesInfo.isUnswitched(SI, AllZero))
689+
continue;
690+
// We are unswitching 0 out.
691+
UnswitchVal = AllZero;
692+
} else if (OpChain == OC_OpChainOr) {
693+
// If the chain only has ORs and the switch has a case value of ~0.
694+
// Dropping in a ~0 to the chain will unswitch out the ~0-casevalue.
695+
auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType()));
696+
if (BranchesInfo.isUnswitched(SI, AllOne))
697+
continue;
698+
// We are unswitching ~0 out.
699+
UnswitchVal = AllOne;
700+
} else {
701+
assert(OpChain == OC_OpChainNone &&
702+
"Expect to unswitch on trivial chain");
703+
// Do not process same value again and again.
704+
// At this point we have some cases already unswitched and
705+
// some not yet unswitched. Let's find the first not yet unswitched one.
706+
for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
707+
i != e; ++i) {
708+
Constant *UnswitchValCandidate = i.getCaseValue();
709+
if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
710+
UnswitchVal = UnswitchValCandidate;
711+
break;
712+
}
626713
}
627714
}
628715

@@ -631,6 +718,11 @@ bool LoopUnswitch::processCurrentLoop() {
631718

632719
if (UnswitchIfProfitable(LoopCond, UnswitchVal)) {
633720
++NumSwitches;
721+
// In case of a full LIV, UnswitchVal is the value we unswitched out.
722+
// In case of a partial LIV, we only unswitch when its an AND-chain
723+
// or OR-chain. In both cases switch input value simplifies to
724+
// UnswitchVal.
725+
BranchesInfo.setUnswitched(SI, UnswitchVal);
634726
return true;
635727
}
636728
}
@@ -641,7 +733,7 @@ bool LoopUnswitch::processCurrentLoop() {
641733
BBI != E; ++BBI)
642734
if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
643735
Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
644-
currentLoop, Changed);
736+
currentLoop, Changed).first;
645737
if (LoopCond && UnswitchIfProfitable(LoopCond,
646738
ConstantInt::getTrue(Context))) {
647739
++NumSelects;
@@ -900,7 +992,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
900992
return false;
901993

902994
Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
903-
currentLoop, Changed);
995+
currentLoop, Changed).first;
904996

905997
// Unswitch only if the trivial condition itself is an LIV (not
906998
// partial LIV which could occur in and/or)
@@ -931,7 +1023,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
9311023
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
9321024
// If this isn't switching on an invariant condition, we can't unswitch it.
9331025
Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
934-
currentLoop, Changed);
1026+
currentLoop, Changed).first;
9351027

9361028
// Unswitch only if the trivial condition itself is an LIV (not
9371029
// partial LIV which could occur in and/or)
@@ -969,6 +1061,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
9691061

9701062
UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
9711063
nullptr);
1064+
1065+
// We are only unswitching full LIV.
1066+
BranchesInfo.setUnswitched(SI, CondVal);
9721067
++NumSwitches;
9731068
return true;
9741069
}
@@ -1250,6 +1345,9 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
12501345
SwitchInst *SI = dyn_cast<SwitchInst>(UI);
12511346
if (!SI || !isa<ConstantInt>(Val)) continue;
12521347

1348+
// NOTE: if a case value for the switch is unswitched out, we record it
1349+
// after the unswitch finishes. We can not record it here as the switch
1350+
// is not a direct user of the partial LIV.
12531351
SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast<ConstantInt>(Val));
12541352
// Default case is live for multiple values.
12551353
if (DeadCase == SI->case_default()) continue;
@@ -1262,8 +1360,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
12621360
BasicBlock *SISucc = DeadCase.getCaseSuccessor();
12631361
BasicBlock *Latch = L->getLoopLatch();
12641362

1265-
BranchesInfo.setUnswitched(SI, Val);
1266-
12671363
if (!SI->findCaseDest(SISucc)) continue; // Edge is critical.
12681364
// If the DeadCase successor dominates the loop latch, then the
12691365
// transformation isn't safe since it will delete the sole predecessor edge

0 commit comments

Comments
 (0)