Skip to content

Commit 535d8e8

Browse files
committed
NFC: Extract switch lowering binary tree splitting code from DAG into SwitchLoweringUtils.
This will help re-use this code with the upcoming GlobalISel implementation of this optimization.
1 parent b306a9c commit 535d8e8

File tree

4 files changed

+99
-84
lines changed

4 files changed

+99
-84
lines changed

llvm/include/llvm/CodeGen/SwitchLoweringUtils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,22 @@ class SwitchLowering {
293293
MachineBasicBlock *Src, MachineBasicBlock *Dst,
294294
BranchProbability Prob = BranchProbability::getUnknown()) = 0;
295295

296+
/// Determine the rank by weight of CC in [First,Last]. If CC has more weight
297+
/// than each cluster in the range, its rank is 0.
298+
unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
299+
CaseClusterIt Last);
300+
301+
struct SplitWorkItemInfo {
302+
CaseClusterIt LastLeft;
303+
CaseClusterIt FirstRight;
304+
BranchProbability LeftProb;
305+
BranchProbability RightProb;
306+
};
307+
/// Compute information to balance the tree based on branch probabilities to
308+
/// create a near-optimal (in terms of search time given key frequency) binary
309+
/// search tree. See e.g. Kurt Mehlhorn "Nearly Optimal Binary Search Trees"
310+
/// (1975).
311+
SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W);
296312
virtual ~SwitchLowering() = default;
297313

298314
private:

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -11639,92 +11639,16 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
1163911639
}
1164011640
}
1164111641

11642-
unsigned SelectionDAGBuilder::caseClusterRank(const CaseCluster &CC,
11643-
CaseClusterIt First,
11644-
CaseClusterIt Last) {
11645-
return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
11646-
if (X.Prob != CC.Prob)
11647-
return X.Prob > CC.Prob;
11648-
11649-
// Ties are broken by comparing the case value.
11650-
return X.Low->getValue().slt(CC.Low->getValue());
11651-
});
11652-
}
11653-
1165411642
void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
1165511643
const SwitchWorkListItem &W,
1165611644
Value *Cond,
1165711645
MachineBasicBlock *SwitchMBB) {
1165811646
assert(W.FirstCluster->Low->getValue().slt(W.LastCluster->Low->getValue()) &&
1165911647
"Clusters not sorted?");
11660-
1166111648
assert(W.LastCluster - W.FirstCluster + 1 >= 2 && "Too small to split!");
1166211649

11663-
// Balance the tree based on branch probabilities to create a near-optimal (in
11664-
// terms of search time given key frequency) binary search tree. See e.g. Kurt
11665-
// Mehlhorn "Nearly Optimal Binary Search Trees" (1975).
11666-
CaseClusterIt LastLeft = W.FirstCluster;
11667-
CaseClusterIt FirstRight = W.LastCluster;
11668-
auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
11669-
auto RightProb = FirstRight->Prob + W.DefaultProb / 2;
11670-
11671-
// Move LastLeft and FirstRight towards each other from opposite directions to
11672-
// find a partitioning of the clusters which balances the probability on both
11673-
// sides. If LeftProb and RightProb are equal, alternate which side is
11674-
// taken to ensure 0-probability nodes are distributed evenly.
11675-
unsigned I = 0;
11676-
while (LastLeft + 1 < FirstRight) {
11677-
if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))
11678-
LeftProb += (++LastLeft)->Prob;
11679-
else
11680-
RightProb += (--FirstRight)->Prob;
11681-
I++;
11682-
}
11683-
11684-
while (true) {
11685-
// Our binary search tree differs from a typical BST in that ours can have up
11686-
// to three values in each leaf. The pivot selection above doesn't take that
11687-
// into account, which means the tree might require more nodes and be less
11688-
// efficient. We compensate for this here.
11689-
11690-
unsigned NumLeft = LastLeft - W.FirstCluster + 1;
11691-
unsigned NumRight = W.LastCluster - FirstRight + 1;
11692-
11693-
if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
11694-
// If one side has less than 3 clusters, and the other has more than 3,
11695-
// consider taking a cluster from the other side.
11696-
11697-
if (NumLeft < NumRight) {
11698-
// Consider moving the first cluster on the right to the left side.
11699-
CaseCluster &CC = *FirstRight;
11700-
unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
11701-
unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
11702-
if (LeftSideRank <= RightSideRank) {
11703-
// Moving the cluster to the left does not demote it.
11704-
++LastLeft;
11705-
++FirstRight;
11706-
continue;
11707-
}
11708-
} else {
11709-
assert(NumRight < NumLeft);
11710-
// Consider moving the last element on the left to the right side.
11711-
CaseCluster &CC = *LastLeft;
11712-
unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
11713-
unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
11714-
if (RightSideRank <= LeftSideRank) {
11715-
// Moving the cluster to the right does not demot it.
11716-
--LastLeft;
11717-
--FirstRight;
11718-
continue;
11719-
}
11720-
}
11721-
}
11722-
break;
11723-
}
11724-
11725-
assert(LastLeft + 1 == FirstRight);
11726-
assert(LastLeft >= W.FirstCluster);
11727-
assert(FirstRight <= W.LastCluster);
11650+
auto [LastLeft, FirstRight, LeftProb, RightProb] =
11651+
SL->computeSplitWorkItemInfo(W);
1172811652

1172911653
// Use the first element on the right as pivot since we will make less-than
1173011654
// comparisons against it.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,6 @@ class SelectionDAGBuilder {
200200
/// create.
201201
unsigned SDNodeOrder;
202202

203-
/// Determine the rank by weight of CC in [First,Last]. If CC has more weight
204-
/// than each cluster in the range, its rank is 0.
205-
unsigned caseClusterRank(const SwitchCG::CaseCluster &CC,
206-
SwitchCG::CaseClusterIt First,
207-
SwitchCG::CaseClusterIt Last);
208-
209203
/// Emit comparison and split W into two subtrees.
210204
void splitWorkItem(SwitchCG::SwitchWorkList &WorkList,
211205
const SwitchCG::SwitchWorkListItem &W, Value *Cond,

llvm/lib/CodeGen/SwitchLoweringUtils.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,84 @@ void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
494494
}
495495
Clusters.resize(DstIndex);
496496
}
497+
498+
unsigned SwitchCG::SwitchLowering::caseClusterRank(const CaseCluster &CC,
499+
CaseClusterIt First,
500+
CaseClusterIt Last) {
501+
return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
502+
if (X.Prob != CC.Prob)
503+
return X.Prob > CC.Prob;
504+
505+
// Ties are broken by comparing the case value.
506+
return X.Low->getValue().slt(CC.Low->getValue());
507+
});
508+
}
509+
510+
llvm::SwitchCG::SwitchLowering::SplitWorkItemInfo
511+
SwitchCG::SwitchLowering::computeSplitWorkItemInfo(
512+
const SwitchWorkListItem &W) {
513+
CaseClusterIt LastLeft = W.FirstCluster;
514+
CaseClusterIt FirstRight = W.LastCluster;
515+
auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
516+
auto RightProb = FirstRight->Prob + W.DefaultProb / 2;
517+
518+
// Move LastLeft and FirstRight towards each other from opposite directions to
519+
// find a partitioning of the clusters which balances the probability on both
520+
// sides. If LeftProb and RightProb are equal, alternate which side is
521+
// taken to ensure 0-probability nodes are distributed evenly.
522+
unsigned I = 0;
523+
while (LastLeft + 1 < FirstRight) {
524+
if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))
525+
LeftProb += (++LastLeft)->Prob;
526+
else
527+
RightProb += (--FirstRight)->Prob;
528+
I++;
529+
}
530+
531+
while (true) {
532+
// Our binary search tree differs from a typical BST in that ours can have
533+
// up to three values in each leaf. The pivot selection above doesn't take
534+
// that into account, which means the tree might require more nodes and be
535+
// less efficient. We compensate for this here.
536+
537+
unsigned NumLeft = LastLeft - W.FirstCluster + 1;
538+
unsigned NumRight = W.LastCluster - FirstRight + 1;
539+
540+
if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
541+
// If one side has less than 3 clusters, and the other has more than 3,
542+
// consider taking a cluster from the other side.
543+
544+
if (NumLeft < NumRight) {
545+
// Consider moving the first cluster on the right to the left side.
546+
CaseCluster &CC = *FirstRight;
547+
unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
548+
unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
549+
if (LeftSideRank <= RightSideRank) {
550+
// Moving the cluster to the left does not demote it.
551+
++LastLeft;
552+
++FirstRight;
553+
continue;
554+
}
555+
} else {
556+
assert(NumRight < NumLeft);
557+
// Consider moving the last element on the left to the right side.
558+
CaseCluster &CC = *LastLeft;
559+
unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
560+
unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
561+
if (RightSideRank <= LeftSideRank) {
562+
// Moving the cluster to the right does not demot it.
563+
--LastLeft;
564+
--FirstRight;
565+
continue;
566+
}
567+
}
568+
}
569+
break;
570+
}
571+
572+
assert(LastLeft + 1 == FirstRight);
573+
assert(LastLeft >= W.FirstCluster);
574+
assert(FirstRight <= W.LastCluster);
575+
576+
return SplitWorkItemInfo{LastLeft, FirstRight, LeftProb, RightProb};
577+
}

0 commit comments

Comments
 (0)