Skip to content

Commit 2ff1495

Browse files
committed
[GlobalISel] Implement bit-test switch table optimization.
This is mostly a straight port from SelectionDAG. We re-use the actual bit-test analysis part from SwitchLoweringUtils, which was factored out earlier to support jump-tables. Differential Revision: https://reviews.llvm.org/D85233
1 parent 39de63a commit 2ff1495

File tree

5 files changed

+419
-19
lines changed

5 files changed

+419
-19
lines changed

llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -312,26 +312,38 @@ class IRTranslator : public MachineFunctionPass {
312312
void emitSwitchCase(SwitchCG::CaseBlock &CB, MachineBasicBlock *SwitchBB,
313313
MachineIRBuilder &MIB);
314314

315-
bool lowerJumpTableWorkItem(SwitchCG::SwitchWorkListItem W,
316-
MachineBasicBlock *SwitchMBB,
317-
MachineBasicBlock *CurMBB,
318-
MachineBasicBlock *DefaultMBB,
319-
MachineIRBuilder &MIB,
320-
MachineFunction::iterator BBI,
321-
BranchProbability UnhandledProbs,
322-
SwitchCG::CaseClusterIt I,
323-
MachineBasicBlock *Fallthrough,
324-
bool FallthroughUnreachable);
325-
326-
bool lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
327-
Value *Cond,
315+
/// Generate code for the BitTest header block, which precedes each sequence of
316+
/// BitTestCases.
317+
void emitBitTestHeader(SwitchCG::BitTestBlock &BTB,
318+
MachineBasicBlock *SwitchMBB);
319+
/// Generate code to produce one "bit test" for a given BitTestCase \p B.
320+
void emitBitTestCase(SwitchCG::BitTestBlock &BB, MachineBasicBlock *NextMBB,
321+
BranchProbability BranchProbToNext, Register Reg,
322+
SwitchCG::BitTestCase &B, MachineBasicBlock *SwitchBB);
323+
324+
bool lowerJumpTableWorkItem(
325+
SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
326+
MachineBasicBlock *CurMBB, MachineBasicBlock *DefaultMBB,
327+
MachineIRBuilder &MIB, MachineFunction::iterator BBI,
328+
BranchProbability UnhandledProbs, SwitchCG::CaseClusterIt I,
329+
MachineBasicBlock *Fallthrough, bool FallthroughUnreachable);
330+
331+
bool lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
328332
MachineBasicBlock *Fallthrough,
329333
bool FallthroughUnreachable,
330334
BranchProbability UnhandledProbs,
331335
MachineBasicBlock *CurMBB,
332336
MachineIRBuilder &MIB,
333337
MachineBasicBlock *SwitchMBB);
334338

339+
bool lowerBitTestWorkItem(
340+
SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
341+
MachineBasicBlock *CurMBB, MachineBasicBlock *DefaultMBB,
342+
MachineIRBuilder &MIB, MachineFunction::iterator BBI,
343+
BranchProbability DefaultProb, BranchProbability UnhandledProbs,
344+
SwitchCG::CaseClusterIt I, MachineBasicBlock *Fallthrough,
345+
bool FallthroughUnreachable);
346+
335347
bool lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W, Value *Cond,
336348
MachineBasicBlock *SwitchMBB,
337349
MachineBasicBlock *DefaultMBB,

llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class MachineIRBuilder {
730730
/// depend on bit 0 (for now).
731731
///
732732
/// \return The newly created instruction.
733-
MachineInstrBuilder buildBrCond(Register Tst, MachineBasicBlock &Dest);
733+
MachineInstrBuilder buildBrCond(const SrcOp &Tst, MachineBasicBlock &Dest);
734734

735735
/// Build and insert G_BRINDIRECT \p Tgt
736736
///

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 210 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ bool IRTranslator::translateSwitch(const User &U, MachineIRBuilder &MIB) {
446446
}
447447

448448
SL->findJumpTables(Clusters, &SI, DefaultMBB, nullptr, nullptr);
449+
SL->findBitTestClusters(Clusters, &SI);
449450

450451
LLVM_DEBUG({
451452
dbgs() << "Case clusters: ";
@@ -723,6 +724,156 @@ bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
723724
return true;
724725
}
725726

727+
void IRTranslator::emitBitTestHeader(SwitchCG::BitTestBlock &B,
728+
MachineBasicBlock *SwitchBB) {
729+
MachineIRBuilder &MIB = *CurBuilder;
730+
MIB.setMBB(*SwitchBB);
731+
732+
// Subtract the minimum value.
733+
Register SwitchOpReg = getOrCreateVReg(*B.SValue);
734+
735+
LLT SwitchOpTy = MRI->getType(SwitchOpReg);
736+
Register MinValReg = MIB.buildConstant(SwitchOpTy, B.First).getReg(0);
737+
auto RangeSub = MIB.buildSub(SwitchOpTy, SwitchOpReg, MinValReg);
738+
739+
// Ensure that the type will fit the mask value.
740+
LLT MaskTy = SwitchOpTy;
741+
for (unsigned I = 0, E = B.Cases.size(); I != E; ++I) {
742+
if (!isUIntN(SwitchOpTy.getSizeInBits(), B.Cases[I].Mask)) {
743+
// Switch table case ranges are encoded into a series of masks.
744+
// Just use a 64-bit scalar type, it's guaranteed to fit.
745+
MaskTy = LLT::scalar(64);
746+
break;
747+
}
748+
}
749+
Register SubReg = RangeSub.getReg(0);
750+
if (SwitchOpTy != MaskTy)
751+
SubReg = MIB.buildZExtOrTrunc(MaskTy, SubReg).getReg(0);
752+
753+
B.RegVT = getMVTForLLT(MaskTy);
754+
B.Reg = SubReg;
755+
756+
MachineBasicBlock *MBB = B.Cases[0].ThisBB;
757+
758+
if (!B.OmitRangeCheck)
759+
addSuccessorWithProb(SwitchBB, B.Default, B.DefaultProb);
760+
addSuccessorWithProb(SwitchBB, MBB, B.Prob);
761+
762+
SwitchBB->normalizeSuccProbs();
763+
764+
if (!B.OmitRangeCheck) {
765+
// Conditional branch to the default block.
766+
auto RangeCst = MIB.buildConstant(SwitchOpTy, B.Range);
767+
auto RangeCmp = MIB.buildICmp(CmpInst::Predicate::ICMP_UGT, LLT::scalar(1),
768+
RangeSub, RangeCst);
769+
MIB.buildBrCond(RangeCmp, *B.Default);
770+
}
771+
772+
// Avoid emitting unnecessary branches to the next block.
773+
if (MBB != SwitchBB->getNextNode())
774+
MIB.buildBr(*MBB);
775+
}
776+
777+
void IRTranslator::emitBitTestCase(SwitchCG::BitTestBlock &BB,
778+
MachineBasicBlock *NextMBB,
779+
BranchProbability BranchProbToNext,
780+
Register Reg, SwitchCG::BitTestCase &B,
781+
MachineBasicBlock *SwitchBB) {
782+
MachineIRBuilder &MIB = *CurBuilder;
783+
MIB.setMBB(*SwitchBB);
784+
785+
LLT SwitchTy = getLLTForMVT(BB.RegVT);
786+
Register Cmp;
787+
unsigned PopCount = countPopulation(B.Mask);
788+
if (PopCount == 1) {
789+
// Testing for a single bit; just compare the shift count with what it
790+
// would need to be to shift a 1 bit in that position.
791+
auto MaskTrailingZeros =
792+
MIB.buildConstant(SwitchTy, countTrailingZeros(B.Mask));
793+
Cmp =
794+
MIB.buildICmp(ICmpInst::ICMP_EQ, LLT::scalar(1), Reg, MaskTrailingZeros)
795+
.getReg(0);
796+
} else if (PopCount == BB.Range) {
797+
// There is only one zero bit in the range, test for it directly.
798+
auto MaskTrailingOnes =
799+
MIB.buildConstant(SwitchTy, countTrailingOnes(B.Mask));
800+
Cmp = MIB.buildICmp(CmpInst::ICMP_NE, LLT::scalar(1), Reg, MaskTrailingOnes)
801+
.getReg(0);
802+
} else {
803+
// Make desired shift.
804+
auto CstOne = MIB.buildConstant(SwitchTy, 1);
805+
auto SwitchVal = MIB.buildShl(SwitchTy, CstOne, Reg);
806+
807+
// Emit bit tests and jumps.
808+
auto CstMask = MIB.buildConstant(SwitchTy, B.Mask);
809+
auto AndOp = MIB.buildAnd(SwitchTy, SwitchVal, CstMask);
810+
auto CstZero = MIB.buildConstant(SwitchTy, 0);
811+
Cmp = MIB.buildICmp(CmpInst::ICMP_NE, LLT::scalar(1), AndOp, CstZero)
812+
.getReg(0);
813+
}
814+
815+
// The branch probability from SwitchBB to B.TargetBB is B.ExtraProb.
816+
addSuccessorWithProb(SwitchBB, B.TargetBB, B.ExtraProb);
817+
// The branch probability from SwitchBB to NextMBB is BranchProbToNext.
818+
addSuccessorWithProb(SwitchBB, NextMBB, BranchProbToNext);
819+
// It is not guaranteed that the sum of B.ExtraProb and BranchProbToNext is
820+
// one as they are relative probabilities (and thus work more like weights),
821+
// and hence we need to normalize them to let the sum of them become one.
822+
SwitchBB->normalizeSuccProbs();
823+
824+
// Record the fact that the IR edge from the header to the bit test target
825+
// will go through our new block. Needed for PHIs to have nodes added.
826+
addMachineCFGPred({BB.Parent->getBasicBlock(), B.TargetBB->getBasicBlock()},
827+
SwitchBB);
828+
829+
MIB.buildBrCond(Cmp, *B.TargetBB);
830+
831+
// Avoid emitting unnecessary branches to the next block.
832+
if (NextMBB != SwitchBB->getNextNode())
833+
MIB.buildBr(*NextMBB);
834+
}
835+
836+
bool IRTranslator::lowerBitTestWorkItem(
837+
SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
838+
MachineBasicBlock *CurMBB, MachineBasicBlock *DefaultMBB,
839+
MachineIRBuilder &MIB, MachineFunction::iterator BBI,
840+
BranchProbability DefaultProb, BranchProbability UnhandledProbs,
841+
SwitchCG::CaseClusterIt I, MachineBasicBlock *Fallthrough,
842+
bool FallthroughUnreachable) {
843+
using namespace SwitchCG;
844+
MachineFunction *CurMF = SwitchMBB->getParent();
845+
// FIXME: Optimize away range check based on pivot comparisons.
846+
BitTestBlock *BTB = &SL->BitTestCases[I->BTCasesIndex];
847+
// The bit test blocks haven't been inserted yet; insert them here.
848+
for (BitTestCase &BTC : BTB->Cases)
849+
CurMF->insert(BBI, BTC.ThisBB);
850+
851+
// Fill in fields of the BitTestBlock.
852+
BTB->Parent = CurMBB;
853+
BTB->Default = Fallthrough;
854+
855+
BTB->DefaultProb = UnhandledProbs;
856+
// If the cases in bit test don't form a contiguous range, we evenly
857+
// distribute the probability on the edge to Fallthrough to two
858+
// successors of CurMBB.
859+
if (!BTB->ContiguousRange) {
860+
BTB->Prob += DefaultProb / 2;
861+
BTB->DefaultProb -= DefaultProb / 2;
862+
}
863+
864+
if (FallthroughUnreachable) {
865+
// Skip the range check if the fallthrough block is unreachable.
866+
BTB->OmitRangeCheck = true;
867+
}
868+
869+
// If we're in the right place, emit the bit test header right now.
870+
if (CurMBB == SwitchMBB) {
871+
emitBitTestHeader(*BTB, SwitchMBB);
872+
BTB->Emitted = true;
873+
}
874+
return true;
875+
}
876+
726877
bool IRTranslator::lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W,
727878
Value *Cond,
728879
MachineBasicBlock *SwitchMBB,
@@ -783,9 +934,15 @@ bool IRTranslator::lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W,
783934

784935
switch (I->Kind) {
785936
case CC_BitTests: {
786-
LLVM_DEBUG(dbgs() << "Switch to bit test optimization unimplemented");
787-
return false; // Bit tests currently unimplemented.
937+
if (!lowerBitTestWorkItem(W, SwitchMBB, CurMBB, DefaultMBB, MIB, BBI,
938+
DefaultProb, UnhandledProbs, I, Fallthrough,
939+
FallthroughUnreachable)) {
940+
LLVM_DEBUG(dbgs() << "Failed to lower bit test for switch");
941+
return false;
942+
}
943+
break;
788944
}
945+
789946
case CC_JumpTable: {
790947
if (!lowerJumpTableWorkItem(W, SwitchMBB, CurMBB, DefaultMBB, MIB, BBI,
791948
UnhandledProbs, I, Fallthrough,
@@ -2349,6 +2506,57 @@ bool IRTranslator::translate(const Constant &C, Register Reg) {
23492506
}
23502507

23512508
void IRTranslator::finalizeBasicBlock() {
2509+
for (auto &BTB : SL->BitTestCases) {
2510+
// Emit header first, if it wasn't already emitted.
2511+
if (!BTB.Emitted)
2512+
emitBitTestHeader(BTB, BTB.Parent);
2513+
2514+
BranchProbability UnhandledProb = BTB.Prob;
2515+
for (unsigned j = 0, ej = BTB.Cases.size(); j != ej; ++j) {
2516+
UnhandledProb -= BTB.Cases[j].ExtraProb;
2517+
// Set the current basic block to the mbb we wish to insert the code into
2518+
MachineBasicBlock *MBB = BTB.Cases[j].ThisBB;
2519+
// If all cases cover a contiguous range, it is not necessary to jump to
2520+
// the default block after the last bit test fails. This is because the
2521+
// range check during bit test header creation has guaranteed that every
2522+
// case here doesn't go outside the range. In this case, there is no need
2523+
// to perform the last bit test, as it will always be true. Instead, make
2524+
// the second-to-last bit-test fall through to the target of the last bit
2525+
// test, and delete the last bit test.
2526+
2527+
MachineBasicBlock *NextMBB;
2528+
if (BTB.ContiguousRange && j + 2 == ej) {
2529+
// Second-to-last bit-test with contiguous range: fall through to the
2530+
// target of the final bit test.
2531+
NextMBB = BTB.Cases[j + 1].TargetBB;
2532+
} else if (j + 1 == ej) {
2533+
// For the last bit test, fall through to Default.
2534+
NextMBB = BTB.Default;
2535+
} else {
2536+
// Otherwise, fall through to the next bit test.
2537+
NextMBB = BTB.Cases[j + 1].ThisBB;
2538+
}
2539+
2540+
emitBitTestCase(BTB, NextMBB, UnhandledProb, BTB.Reg, BTB.Cases[j], MBB);
2541+
2542+
// FIXME delete this block below?
2543+
if (BTB.ContiguousRange && j + 2 == ej) {
2544+
// Since we're not going to use the final bit test, remove it.
2545+
BTB.Cases.pop_back();
2546+
break;
2547+
}
2548+
}
2549+
// This is "default" BB. We have two jumps to it. From "header" BB and from
2550+
// last "case" BB, unless the latter was skipped.
2551+
CFGEdge HeaderToDefaultEdge = {BTB.Parent->getBasicBlock(),
2552+
BTB.Default->getBasicBlock()};
2553+
addMachineCFGPred(HeaderToDefaultEdge, BTB.Parent);
2554+
if (!BTB.ContiguousRange) {
2555+
addMachineCFGPred(HeaderToDefaultEdge, BTB.Cases.back().ThisBB);
2556+
}
2557+
}
2558+
SL->BitTestCases.clear();
2559+
23522560
for (auto &JTCase : SL->JTCases) {
23532561
// Emit header first, if it wasn't already emitted.
23542562
if (!JTCase.first.Emitted)

llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,11 +317,14 @@ MachineInstrBuilder MachineIRBuilder::buildFConstant(const DstOp &Res,
317317
return buildFConstant(Res, *CFP);
318318
}
319319

320-
MachineInstrBuilder MachineIRBuilder::buildBrCond(Register Tst,
320+
MachineInstrBuilder MachineIRBuilder::buildBrCond(const SrcOp &Tst,
321321
MachineBasicBlock &Dest) {
322-
assert(getMRI()->getType(Tst).isScalar() && "invalid operand type");
322+
assert(Tst.getLLTTy(*getMRI()).isScalar() && "invalid operand type");
323323

324-
return buildInstr(TargetOpcode::G_BRCOND).addUse(Tst).addMBB(&Dest);
324+
auto MIB = buildInstr(TargetOpcode::G_BRCOND);
325+
Tst.addSrcToMIB(MIB);
326+
MIB.addMBB(&Dest);
327+
return MIB;
325328
}
326329

327330
MachineInstrBuilder

0 commit comments

Comments
 (0)