Skip to content

Commit 5932fcc

Browse files
authored
[InlineCost] Consider the default branch when calculating cost (#77856)
First step in fixing #76772. This PR considers the default branch as a case branch. This will give the unreachable default branch fair consideration.
1 parent 5aec939 commit 5932fcc

File tree

5 files changed

+555
-8
lines changed

5 files changed

+555
-8
lines changed

llvm/include/llvm/Analysis/InlineModelFeatureMaps.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ namespace llvm {
3939
M(int64_t, {1}, jump_table_penalty, "Accumulation of costs for jump tables") \
4040
M(int64_t, {1}, case_cluster_penalty, \
4141
"Accumulation of costs for case clusters") \
42+
M(int64_t, {1}, switch_default_dest_penalty, \
43+
"Accumulation of costs for switch default destination") \
4244
M(int64_t, {1}, switch_penalty, \
4345
"Accumulation of costs for switch statements") \
4446
M(int64_t, {1}, unsimplified_common_instructions, \

llvm/include/llvm/IR/Instructions.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class DataLayout;
4949
class StringRef;
5050
class Type;
5151
class Value;
52+
class UnreachableInst;
5253

5354
//===----------------------------------------------------------------------===//
5455
// AllocaInst Class
@@ -3505,6 +3506,12 @@ class SwitchInst : public Instruction {
35053506
return cast<BasicBlock>(getOperand(1));
35063507
}
35073508

3509+
/// Returns true if the default branch must result in immediate undefined
3510+
/// behavior, false otherwise.
3511+
bool defaultDestUndefined() const {
3512+
return isa<UnreachableInst>(getDefaultDest()->getFirstNonPHIOrDbg());
3513+
}
3514+
35083515
void setDefaultDest(BasicBlock *DefaultCase) {
35093516
setOperand(1, reinterpret_cast<Value*>(DefaultCase));
35103517
}

llvm/lib/Analysis/InlineCost.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,8 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
336336

337337
/// Called at the end of processing a switch instruction, with the given
338338
/// number of case clusters.
339-
virtual void onFinalizeSwitch(unsigned JumpTableSize,
340-
unsigned NumCaseCluster) {}
339+
virtual void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
340+
bool DefaultDestUndefined) {}
341341

342342
/// Called to account for any other instruction not specifically accounted
343343
/// for.
@@ -699,15 +699,16 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
699699
CallPenalty));
700700
}
701701

702-
void onFinalizeSwitch(unsigned JumpTableSize,
703-
unsigned NumCaseCluster) override {
702+
void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
703+
bool DefaultDestUndefined) override {
704+
if (!DefaultDestUndefined)
705+
addCost(2 * InstrCost);
704706
// If suitable for a jump table, consider the cost for the table size and
705707
// branch to destination.
706708
// Maximum valid cost increased in this function.
707709
if (JumpTableSize) {
708710
int64_t JTCost =
709711
static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
710-
711712
addCost(JTCost);
712713
return;
713714
}
@@ -1153,6 +1154,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
11531154
// heuristics in the ML inliner.
11541155
static constexpr int JTCostMultiplier = 4;
11551156
static constexpr int CaseClusterCostMultiplier = 2;
1157+
static constexpr int SwitchDefaultDestCostMultiplier = 2;
11561158
static constexpr int SwitchCostMultiplier = 2;
11571159

11581160
// FIXME: These are taken from the heuristic-based cost visitor: we should
@@ -1231,8 +1233,11 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
12311233
}
12321234
}
12331235

1234-
void onFinalizeSwitch(unsigned JumpTableSize,
1235-
unsigned NumCaseCluster) override {
1236+
void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
1237+
bool DefaultDestUndefined) override {
1238+
if (!DefaultDestUndefined)
1239+
increment(InlineCostFeatureIndex::switch_default_dest_penalty,
1240+
SwitchDefaultDestCostMultiplier * InstrCost);
12361241

12371242
if (JumpTableSize) {
12381243
int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
@@ -2461,7 +2466,7 @@ bool CallAnalyzer::visitSwitchInst(SwitchInst &SI) {
24612466
unsigned NumCaseCluster =
24622467
TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize, PSI, BFI);
24632468

2464-
onFinalizeSwitch(JumpTableSize, NumCaseCluster);
2469+
onFinalizeSwitch(JumpTableSize, NumCaseCluster, SI.defaultDestUndefined());
24652470
return false;
24662471
}
24672472

0 commit comments

Comments
 (0)