Skip to content

Commit 4d85285

Browse files
authored
[SimplifyCFG] Fold switch over ucmp/scmp to icmp and br (#105636)
If we switch over ucmp/scmp and have two switch cases going to the same destination, we can convert into icmp+br. Fixes #105632.
1 parent 58ac764 commit 4d85285

File tree

2 files changed

+486
-46
lines changed

2 files changed

+486
-46
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7131,6 +7131,119 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
71317131
return true;
71327132
}
71337133

7134+
/// Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have
7135+
/// the same destination.
7136+
static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
7137+
DomTreeUpdater *DTU) {
7138+
auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition());
7139+
if (!Cmp || !Cmp->hasOneUse())
7140+
return false;
7141+
7142+
SmallVector<uint32_t, 4> Weights;
7143+
bool HasWeights = extractBranchWeights(getBranchWeightMDNode(*SI), Weights);
7144+
if (!HasWeights)
7145+
Weights.resize(4); // Avoid checking HasWeights everywhere.
7146+
7147+
// Normalize to [us]cmp == Res ? Succ : OtherSucc.
7148+
int64_t Res;
7149+
BasicBlock *Succ, *OtherSucc;
7150+
uint32_t SuccWeight = 0, OtherSuccWeight = 0;
7151+
BasicBlock *Unreachable = nullptr;
7152+
7153+
if (SI->getNumCases() == 2) {
7154+
// Find which of 1, 0 or -1 is missing (handled by default dest).
7155+
SmallSet<int64_t, 3> Missing;
7156+
Missing.insert(1);
7157+
Missing.insert(0);
7158+
Missing.insert(-1);
7159+
7160+
Succ = SI->getDefaultDest();
7161+
SuccWeight = Weights[0];
7162+
OtherSucc = nullptr;
7163+
for (auto &Case : SI->cases()) {
7164+
std::optional<int64_t> Val =
7165+
Case.getCaseValue()->getValue().trySExtValue();
7166+
if (!Val)
7167+
return false;
7168+
if (!Missing.erase(*Val))
7169+
return false;
7170+
if (OtherSucc && OtherSucc != Case.getCaseSuccessor())
7171+
return false;
7172+
OtherSucc = Case.getCaseSuccessor();
7173+
OtherSuccWeight += Weights[Case.getSuccessorIndex()];
7174+
}
7175+
7176+
assert(Missing.size() == 1 && "Should have one case left");
7177+
Res = *Missing.begin();
7178+
} else if (SI->getNumCases() == 3 && SI->defaultDestUndefined()) {
7179+
// Normalize so that Succ is taken once and OtherSucc twice.
7180+
Unreachable = SI->getDefaultDest();
7181+
Succ = OtherSucc = nullptr;
7182+
for (auto &Case : SI->cases()) {
7183+
BasicBlock *NewSucc = Case.getCaseSuccessor();
7184+
uint32_t Weight = Weights[Case.getSuccessorIndex()];
7185+
if (!OtherSucc || OtherSucc == NewSucc) {
7186+
OtherSucc = NewSucc;
7187+
OtherSuccWeight += Weight;
7188+
} else if (!Succ) {
7189+
Succ = NewSucc;
7190+
SuccWeight = Weight;
7191+
} else if (Succ == NewSucc) {
7192+
std::swap(Succ, OtherSucc);
7193+
std::swap(SuccWeight, OtherSuccWeight);
7194+
} else
7195+
return false;
7196+
}
7197+
for (auto &Case : SI->cases()) {
7198+
std::optional<int64_t> Val =
7199+
Case.getCaseValue()->getValue().trySExtValue();
7200+
if (!Val || (Val != 1 && Val != 0 && Val != -1))
7201+
return false;
7202+
if (Case.getCaseSuccessor() == Succ) {
7203+
Res = *Val;
7204+
break;
7205+
}
7206+
}
7207+
} else {
7208+
return false;
7209+
}
7210+
7211+
// Determine predicate for the missing case.
7212+
ICmpInst::Predicate Pred;
7213+
switch (Res) {
7214+
case 1:
7215+
Pred = ICmpInst::ICMP_UGT;
7216+
break;
7217+
case 0:
7218+
Pred = ICmpInst::ICMP_EQ;
7219+
break;
7220+
case -1:
7221+
Pred = ICmpInst::ICMP_ULT;
7222+
break;
7223+
}
7224+
if (Cmp->isSigned())
7225+
Pred = ICmpInst::getSignedPredicate(Pred);
7226+
7227+
MDNode *NewWeights = nullptr;
7228+
if (HasWeights)
7229+
NewWeights = MDBuilder(SI->getContext())
7230+
.createBranchWeights(SuccWeight, OtherSuccWeight);
7231+
7232+
BasicBlock *BB = SI->getParent();
7233+
Builder.SetInsertPoint(SI->getIterator());
7234+
Value *ICmp = Builder.CreateICmp(Pred, Cmp->getLHS(), Cmp->getRHS());
7235+
Builder.CreateCondBr(ICmp, Succ, OtherSucc, NewWeights,
7236+
SI->getMetadata(LLVMContext::MD_unpredictable));
7237+
OtherSucc->removePredecessor(BB);
7238+
if (Unreachable)
7239+
Unreachable->removePredecessor(BB);
7240+
SI->eraseFromParent();
7241+
Cmp->eraseFromParent();
7242+
if (DTU && Unreachable)
7243+
DTU->applyUpdates({{DominatorTree::Delete, BB, Unreachable}});
7244+
return true;
7245+
}
7246+
71347247
bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
71357248
BasicBlock *BB = SI->getParent();
71367249

@@ -7163,6 +7276,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
71637276
if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL))
71647277
return requestResimplify();
71657278

7279+
if (simplifySwitchOfCmpIntrinsic(SI, Builder, DTU))
7280+
return requestResimplify();
7281+
71667282
if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
71677283
return requestResimplify();
71687284

0 commit comments

Comments
 (0)