Skip to content

[SwitchLowering] Support merging 0 and power-of-2 case. #139736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,13 @@ class IRTranslator : public MachineFunctionPass {
BranchProbability UnhandledProbs, SwitchCG::CaseClusterIt I,
MachineBasicBlock *Fallthrough, bool FallthroughUnreachable);

bool lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
MachineBasicBlock *Fallthrough,
bool FallthroughUnreachable,
BranchProbability UnhandledProbs,
MachineBasicBlock *CurMBB,
MachineIRBuilder &MIB,
MachineBasicBlock *SwitchMBB);
bool lowerSwitchAndOrRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
MachineBasicBlock *Fallthrough,
bool FallthroughUnreachable,
BranchProbability UnhandledProbs,
MachineBasicBlock *CurMBB,
MachineIRBuilder &MIB,
MachineBasicBlock *SwitchMBB);

bool lowerBitTestWorkItem(
SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
Expand Down
5 changes: 4 additions & 1 deletion llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ enum CaseClusterKind {
/// A cluster of cases suitable for jump table lowering.
CC_JumpTable,
/// A cluster of cases suitable for bit test lowering.
CC_BitTests
CC_BitTests,
CC_And
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deserves a short comment explaining when it can be used, like the others.

};

/// A cluster of case labels.
Expand Down Expand Up @@ -141,6 +142,8 @@ struct CaseBlock {
BranchProbability TrueProb, FalseProb;
bool IsUnpredictable;

bool EmitAnd = false;

// Constructor for SelectionDAG.
CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
const Value *cmpmiddle, MachineBasicBlock *truebb,
Expand Down
32 changes: 19 additions & 13 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,14 @@ void IRTranslator::emitSwitchCase(SwitchCG::CaseBlock &CB,

const LLT i1Ty = LLT::scalar(1);
// Build the compare.
if (!CB.CmpMHS) {
if (CB.EmitAnd) {
const LLT Ty = getLLTForType(*CB.CmpRHS->getType(), *DL);
Register CondLHS = getOrCreateVReg(*CB.CmpLHS);
Register C = getOrCreateVReg(*CB.CmpRHS);
Register And = MIB.buildAnd(Ty, CondLHS, C).getReg(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Register And = MIB.buildAnd(Ty, CondLHS, C).getReg(0);
auto And = MIB.buildAnd(ty, CondLHS, C);

auto Zero = MIB.buildConstant(Ty, 0);
Cond = MIB.buildICmp(CmpInst::ICMP_EQ, i1Ty, And, Zero).getReg(0);
} else if (!CB.CmpMHS) {
const auto *CI = dyn_cast<ConstantInt>(CB.CmpRHS);
// For conditional branch lowering, we might try to do something silly like
// emit an G_ICMP to compare an existing G_ICMP i1 result with true. If so,
Expand Down Expand Up @@ -1059,18 +1066,15 @@ bool IRTranslator::lowerJumpTableWorkItem(SwitchCG::SwitchWorkListItem W,
}
return true;
}
bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
Value *Cond,
MachineBasicBlock *Fallthrough,
bool FallthroughUnreachable,
BranchProbability UnhandledProbs,
MachineBasicBlock *CurMBB,
MachineIRBuilder &MIB,
MachineBasicBlock *SwitchMBB) {
bool IRTranslator::lowerSwitchAndOrRangeWorkItem(
SwitchCG::CaseClusterIt I, Value *Cond, MachineBasicBlock *Fallthrough,
bool FallthroughUnreachable, BranchProbability UnhandledProbs,
MachineBasicBlock *CurMBB, MachineIRBuilder &MIB,
MachineBasicBlock *SwitchMBB) {
using namespace SwitchCG;
const Value *RHS, *LHS, *MHS;
CmpInst::Predicate Pred;
if (I->Low == I->High) {
if (I->Low == I->High || I->Kind == CC_And) {
// Check Cond == I->Low.
Pred = CmpInst::ICMP_EQ;
LHS = Cond;
Expand All @@ -1088,6 +1092,7 @@ bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
// The false probability is the sum of all unhandled cases.
CaseBlock CB(Pred, FallthroughUnreachable, LHS, RHS, MHS, I->MBB, Fallthrough,
CurMBB, MIB.getDebugLoc(), I->Prob, UnhandledProbs);
CB.EmitAnd = I->Kind == CC_And;

emitSwitchCase(CB, SwitchMBB, MIB);
return true;
Expand Down Expand Up @@ -1327,10 +1332,11 @@ bool IRTranslator::lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W,
}
break;
}
case CC_And:
case CC_Range: {
if (!lowerSwitchRangeWorkItem(I, Cond, Fallthrough,
FallthroughUnreachable, UnhandledProbs,
CurMBB, MIB, SwitchMBB)) {
if (!lowerSwitchAndOrRangeWorkItem(I, Cond, Fallthrough,
FallthroughUnreachable, UnhandledProbs,
CurMBB, MIB, SwitchMBB)) {
LLVM_DEBUG(dbgs() << "Failed to lower switch range");
return false;
}
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2857,7 +2857,17 @@ void SelectionDAGBuilder::visitSwitchCase(CaseBlock &CB,
EVT MemVT = TLI.getMemValueType(DAG.getDataLayout(), CB.CmpLHS->getType());

// Build the setcc now.
if (!CB.CmpMHS) {
if (CB.EmitAnd) {
SDLoc dl = getCurSDLoc();

const TargetLowering &TLI = DAG.getTargetLoweringInfo();
EVT VT = TLI.getValueType(DAG.getDataLayout(), CB.CmpRHS->getType(), true);
SDValue C = DAG.getConstant(*cast<ConstantInt>(CB.CmpRHS), dl, VT);
SDValue Zero = DAG.getConstant(0, dl, VT);
SDValue CondLHS = getValue(CB.CmpLHS);
SDValue And = DAG.getNode(ISD::AND, dl, C.getValueType(), CondLHS, C);
Cond = DAG.getSetCC(dl, MVT::i1, And, Zero, ISD::SETEQ);
} else if (!CB.CmpMHS) {
// Fold "(X == true)" to X and "(X == false)" to !X to
// handle common cases produced by branch lowering.
if (CB.CmpRHS == ConstantInt::getTrue(*DAG.getContext()) &&
Expand Down Expand Up @@ -12248,6 +12258,7 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
}
break;
}
case CC_And:
case CC_Range: {
const Value *RHS, *LHS, *MHS;
ISD::CondCode CC;
Expand All @@ -12259,6 +12270,7 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
MHS = nullptr;
} else {
// Check I->Low <= Cond <= I->High.
assert(I->Kind != CC_And && "CC_And must be handled above");
CC = ISD::SETLE;
LHS = I->Low;
MHS = Cond;
Expand All @@ -12273,6 +12285,7 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
CaseBlock CB(CC, LHS, RHS, MHS, I->MBB, Fallthrough, CurMBB,
getCurSDLoc(), I->Prob, UnhandledProbs);

CB.EmitAnd = I->Kind == CC_And;
if (CurMBB == SwitchMBB)
visitSwitchCase(CB, SwitchMBB);
else
Expand Down
35 changes: 35 additions & 0 deletions llvm/lib/CodeGen/SwitchLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,41 @@ void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
}
}
Clusters.resize(DstIndex);

// Check if the clusters contain one checking for 0 and another one checking
// for a power-of-2 constant with matching destinations. Those clusters can be
// combined to a single one with CC_And.
unsigned ZeroIdx = -1;
for (const auto &[Idx, C] : enumerate(Clusters)) {
if (C.Kind != CC_Range || C.Low != C.High)
continue;
if (C.Low->isZero()) {
ZeroIdx = Idx;
break;
}
}
if (ZeroIdx == -1u)
return;

unsigned Pow2Idx = -1;
for (const auto &[Idx, C] : enumerate(Clusters)) {
if (C.Kind != CC_Range || C.Low != C.High || C.MBB != Clusters[ZeroIdx].MBB)
continue;
if (C.Low->getValue().isPowerOf2()) {
Pow2Idx = Idx;
break;
}
}
if (Pow2Idx == -1u)
return;

APInt Pow2 = Clusters[Pow2Idx].Low->getValue();
APInt NewC = ~Pow2;
Clusters[ZeroIdx].Low = ConstantInt::get(SI->getContext(), NewC);
Clusters[ZeroIdx].High = ConstantInt::get(SI->getContext(), NewC);
Clusters[ZeroIdx].Kind = CC_And;
Clusters[ZeroIdx].Prob += Clusters[Pow2Idx].Prob;
Clusters.erase(Clusters.begin() + Pow2Idx);
}

bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
Expand Down
Loading
Loading