Skip to content

Commit 72ebbc4

Browse files
committed
[SwitchLowering] Support merging 0 and power-of-2 case.
1 parent 894a0dd commit 72ebbc4

File tree

6 files changed

+178
-165
lines changed

6 files changed

+178
-165
lines changed

llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,13 @@ class IRTranslator : public MachineFunctionPass {
405405
BranchProbability UnhandledProbs, SwitchCG::CaseClusterIt I,
406406
MachineBasicBlock *Fallthrough, bool FallthroughUnreachable);
407407

408-
bool lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
409-
MachineBasicBlock *Fallthrough,
410-
bool FallthroughUnreachable,
411-
BranchProbability UnhandledProbs,
412-
MachineBasicBlock *CurMBB,
413-
MachineIRBuilder &MIB,
414-
MachineBasicBlock *SwitchMBB);
408+
bool lowerSwitchAndOrRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
409+
MachineBasicBlock *Fallthrough,
410+
bool FallthroughUnreachable,
411+
BranchProbability UnhandledProbs,
412+
MachineBasicBlock *CurMBB,
413+
MachineIRBuilder &MIB,
414+
MachineBasicBlock *SwitchMBB);
415415

416416
bool lowerBitTestWorkItem(
417417
SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,

llvm/include/llvm/CodeGen/SwitchLoweringUtils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ enum CaseClusterKind {
3535
/// A cluster of cases suitable for jump table lowering.
3636
CC_JumpTable,
3737
/// A cluster of cases suitable for bit test lowering.
38-
CC_BitTests
38+
CC_BitTests,
39+
CC_And
3940
};
4041

4142
/// A cluster of case labels.
@@ -141,6 +142,8 @@ struct CaseBlock {
141142
BranchProbability TrueProb, FalseProb;
142143
bool IsUnpredictable;
143144

145+
bool EmitAnd = false;
146+
144147
// Constructor for SelectionDAG.
145148
CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
146149
const Value *cmpmiddle, MachineBasicBlock *truebb,

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,18 +1059,15 @@ bool IRTranslator::lowerJumpTableWorkItem(SwitchCG::SwitchWorkListItem W,
10591059
}
10601060
return true;
10611061
}
1062-
bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
1063-
Value *Cond,
1064-
MachineBasicBlock *Fallthrough,
1065-
bool FallthroughUnreachable,
1066-
BranchProbability UnhandledProbs,
1067-
MachineBasicBlock *CurMBB,
1068-
MachineIRBuilder &MIB,
1069-
MachineBasicBlock *SwitchMBB) {
1062+
bool IRTranslator::lowerSwitchAndOrRangeWorkItem(
1063+
SwitchCG::CaseClusterIt I, Value *Cond, MachineBasicBlock *Fallthrough,
1064+
bool FallthroughUnreachable, BranchProbability UnhandledProbs,
1065+
MachineBasicBlock *CurMBB, MachineIRBuilder &MIB,
1066+
MachineBasicBlock *SwitchMBB) {
10701067
using namespace SwitchCG;
10711068
const Value *RHS, *LHS, *MHS;
10721069
CmpInst::Predicate Pred;
1073-
if (I->Low == I->High) {
1070+
if (I->Low == I->High || I->Kind == CC_And) {
10741071
// Check Cond == I->Low.
10751072
Pred = CmpInst::ICMP_EQ;
10761073
LHS = Cond;
@@ -1088,6 +1085,7 @@ bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
10881085
// The false probability is the sum of all unhandled cases.
10891086
CaseBlock CB(Pred, FallthroughUnreachable, LHS, RHS, MHS, I->MBB, Fallthrough,
10901087
CurMBB, MIB.getDebugLoc(), I->Prob, UnhandledProbs);
1088+
CB.EmitAnd = I->Kind == CC_And;
10911089

10921090
emitSwitchCase(CB, SwitchMBB, MIB);
10931091
return true;
@@ -1327,10 +1325,11 @@ bool IRTranslator::lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W,
13271325
}
13281326
break;
13291327
}
1328+
case CC_And:
13301329
case CC_Range: {
1331-
if (!lowerSwitchRangeWorkItem(I, Cond, Fallthrough,
1332-
FallthroughUnreachable, UnhandledProbs,
1333-
CurMBB, MIB, SwitchMBB)) {
1330+
if (!lowerSwitchAndOrRangeWorkItem(I, Cond, Fallthrough,
1331+
FallthroughUnreachable, UnhandledProbs,
1332+
CurMBB, MIB, SwitchMBB)) {
13341333
LLVM_DEBUG(dbgs() << "Failed to lower switch range");
13351334
return false;
13361335
}

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2857,7 +2857,17 @@ void SelectionDAGBuilder::visitSwitchCase(CaseBlock &CB,
28572857
EVT MemVT = TLI.getMemValueType(DAG.getDataLayout(), CB.CmpLHS->getType());
28582858

28592859
// Build the setcc now.
2860-
if (!CB.CmpMHS) {
2860+
if (CB.EmitAnd) {
2861+
SDLoc dl = getCurSDLoc();
2862+
2863+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2864+
EVT VT = TLI.getValueType(DAG.getDataLayout(), CB.CmpRHS->getType(), true);
2865+
SDValue C = DAG.getConstant(*cast<ConstantInt>(CB.CmpRHS), dl, VT);
2866+
SDValue Zero = DAG.getConstant(0, dl, VT);
2867+
SDValue CondLHS = getValue(CB.CmpLHS);
2868+
SDValue And = DAG.getNode(ISD::AND, dl, C.getValueType(), CondLHS, C);
2869+
Cond = DAG.getSetCC(dl, MVT::i1, And, Zero, ISD::SETEQ);
2870+
} else if (!CB.CmpMHS) {
28612871
// Fold "(X == true)" to X and "(X == false)" to !X to
28622872
// handle common cases produced by branch lowering.
28632873
if (CB.CmpRHS == ConstantInt::getTrue(*DAG.getContext()) &&
@@ -12248,10 +12258,11 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
1224812258
}
1224912259
break;
1225012260
}
12261+
case CC_And:
1225112262
case CC_Range: {
1225212263
const Value *RHS, *LHS, *MHS;
1225312264
ISD::CondCode CC;
12254-
if (I->Low == I->High) {
12265+
if (I->Low == I->High || I->Kind == CC_And) {
1225512266
// Check Cond == I->Low.
1225612267
CC = ISD::SETEQ;
1225712268
LHS = Cond;
@@ -12273,6 +12284,7 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
1227312284
CaseBlock CB(CC, LHS, RHS, MHS, I->MBB, Fallthrough, CurMBB,
1227412285
getCurSDLoc(), I->Prob, UnhandledProbs);
1227512286

12287+
CB.EmitAnd = I->Kind == CC_And;
1227612288
if (CurMBB == SwitchMBB)
1227712289
visitSwitchCase(CB, SwitchMBB);
1227812290
else

llvm/lib/CodeGen/SwitchLoweringUtils.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,41 @@ void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
362362
}
363363
}
364364
Clusters.resize(DstIndex);
365+
366+
// Check if the clusters contain one checking for 0 and another one checking
367+
// for a power-of-2 constant with matching destinations. Those clusters can be
368+
// combined to a single ane with CC_And.
369+
unsigned ZeroIdx = -1;
370+
for (const auto &[Idx, C] : enumerate(Clusters)) {
371+
if (C.Kind != CC_Range || C.Low != C.High)
372+
continue;
373+
if (C.Low->isZero()) {
374+
ZeroIdx = Idx;
375+
break;
376+
}
377+
}
378+
if (ZeroIdx == -1u)
379+
return;
380+
381+
unsigned Pow2Idx = -1;
382+
for (const auto &[Idx, C] : enumerate(Clusters)) {
383+
if (C.Kind != CC_Range || C.Low != C.High || C.MBB != Clusters[ZeroIdx].MBB)
384+
continue;
385+
if (C.Low->getValue().isPowerOf2()) {
386+
Pow2Idx = Idx;
387+
break;
388+
}
389+
}
390+
if (Pow2Idx == -1u)
391+
return;
392+
393+
APInt Pow2 = Clusters[Pow2Idx].Low->getValue();
394+
APInt NewC = (Pow2 + 1) * -1;
395+
Clusters[ZeroIdx].Low = ConstantInt::get(SI->getContext(), NewC);
396+
Clusters[ZeroIdx].High = ConstantInt::get(SI->getContext(), NewC);
397+
Clusters[ZeroIdx].Kind = CC_And;
398+
Clusters[ZeroIdx].Prob += Clusters[Pow2Idx].Prob;
399+
Clusters.erase(Clusters.begin() + Pow2Idx);
365400
}
366401

367402
bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,

0 commit comments

Comments
 (0)