[InlineCost] Correct the default branch cost for the switch statement #85160
2c29887 to 9a7e12c
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis
Author: Quentin Dian (DianQK)
Changes
I use the following patch to find functions that are not inlined after #77856.
patch.diff
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e55eaa55f8e9..e325d18ab0a8 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -29,6 +29,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Config/llvm-config.h"
+#include "llvm/Demangle/Demangle.h"
 #include "llvm/IR/AssemblyAnnotationWriter.h"
 #include "llvm/IR/CallingConv.h"
 #include "llvm/IR/DataLayout.h"
@@ -575,6 +576,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
   // True if the cost-benefit-analysis-based inliner is enabled.
   const bool CostBenefitAnalysisEnabled;
+  int DefaultBranchCost = 0;
+
   /// Inlining cost measured in abstract units, accounts for all the
   /// instructions expected to be executed for a given function invocation.
   /// Instructions that are statically proven to be dead based on call-site
@@ -701,8 +704,11 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
+    if (!DefaultDestUndefined) {
+      DefaultBranchCost = std::clamp<int64_t>(DefaultBranchCost + 2 * InstrCost,
+                                              INT_MIN, INT_MAX);
       addCost(2 * InstrCost);
+    }
     // If suitable for a jump table, consider the cost for the table size and
     // branch to destination.
     // Maximum valid cost increased in this function.
@@ -1132,6 +1138,7 @@ public:
   virtual ~InlineCostCallAnalyzer() = default;
   int getThreshold() const { return Threshold; }
   int getCost() const { return Cost; }
+  int getDefaultBranchCost() const { return DefaultBranchCost; }
   int getStaticBonusApplied() const { return StaticBonusApplied; }
   std::optional<CostBenefitPair> getCostBenefitPair() { return CostBenefit; }
   bool wasDecidedByCostBenefit() const { return DecidedByCostBenefit; }
@@ -3072,6 +3079,23 @@ InlineCost llvm::getInlineCost(
                             GetAssumptionCache, GetBFI, PSI, ORE);
   InlineResult ShouldInline = CA.analyze();
+  if (CA.getCost() > CA.getThreshold() &&
+      (CA.getCost() - CA.getDefaultBranchCost() <= CA.getThreshold())) {
+    auto ModuleName = Callee->getParent()->getName();
+    auto *CallerName = llvm::itaniumDemangle(Call.getCaller()->getName());
+    auto *CalleeName = llvm::itaniumDemangle(Callee->getName());
+    errs() << "NOT Inlining ModuleName: " << ModuleName << " Caller: " << CallerName
+           << ", Callee: " << CalleeName << ", Cost: " << CA.getCost()
+           << ", Threshold: " << CA.getThreshold()
+           << ", DefaultBranchCost: " << CA.getDefaultBranchCost();
+    if (auto *SP = Callee->getSubprogram()) {
+      auto FileName = SP->getFilename();
+      unsigned Line = SP->getLine();
+      errs() << ", FileName: " << FileName << "#L" << Line;
+    }
+    errs() << "\n";
+  }
+
   LLVM_DEBUG(CA.dump());
   // Always make cost benefit based decision explicit.
There are over 20,000 call sites that don't satisfy the inline condition. I tried to select 10 of them:
Details
There are complex switch statements that cannot be transformed to simpler structures.
The earliest commit of the related code is: 919f9e8. I tried to understand the following code with #77856 (comment).
llvm-project/llvm/lib/Analysis/InlineCost.cpp
Lines 709 to 720 in 5932fcc
I think only scenarios where there is a default branch were considered.
Taking https://llvm.godbolt.org/z/5cno1TnGx as an example, we need additional compare and jump instructions when there is a default branch, otherwise we just need a jump instruction.
foo: # @foo
  cmp rdi, 6
  ja .LBB0_6
  jmp qword ptr [8*rdi + .LJTI0_0]
  ...
bar: # @bar
  jmp qword ptr [8*rdi + .LJTI1_0]
  ...
But I don't know why it's 4 * InstrCost and not 3 * InstrCost.
Taking https://llvm.godbolt.org/z/MEsf9sno7 as an example, we can reduce a set of compare and jump instructions when the number of branches is small.
foo: # @foo
  cmp rdi, 4
  je .LBB0_5
  cmp rdi, 2
  je .LBB0_4
  test rdi, rdi
  jne .LBB0_6
  ...
bar: # @bar
  cmp rdi, 4
  je .LBB1_4
  cmp rdi, 2
  jne .LBB1_2
  ...
Further, for scenarios with more branches, the number of generated compare instructions should be less than the number of branches if the default branch is undefined behavior. There will be even fewer compare instructions if some cases share a common target branch.
Result of reverting #77856: https://llvm-compile-time-tracker.com/compare.php?from=f3c5278efa3b783ada9e7a34b751cf4c5b864535&to=58622ef6755a02f97e5127bea29ed5b8812fe25e&stat=instructions:u.
Full diff: https://github.com/llvm/llvm-project/pull/85160.diff
3 Files Affected:
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e55eaa55f8e947..9d29d5765c1915 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -536,7 +536,13 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
 // Considering comparisons from leaf and non-leaf nodes, we can estimate the
 // number of comparisons in a simple closed form :
 // n + n / 2 - 1 = n * 3 / 2 - 1
-int64_t getExpectedNumberOfCompare(int NumCaseCluster) {
+int64_t getExpectedNumberOfCompare(int NumCaseCluster,
+                                   bool DefaultDestUndefined) {
+  // The compare instruction count should be less than the branch count
+  // when default branch is undefined.
+  if (DefaultDestUndefined) {
+    return static_cast<int64_t>(NumCaseCluster) - 1;
+  }
   return 3 * static_cast<int64_t>(NumCaseCluster) / 2 - 1;
 }
@@ -701,26 +707,31 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      addCost(2 * InstrCost);
     // If suitable for a jump table, consider the cost for the table size and
     // branch to destination.
     // Maximum valid cost increased in this function.
     if (JumpTableSize) {
+      // Suppose a default branch includes one compare and one conditional
+      // branch if it's reachable.
+      if (!DefaultDestUndefined)
+        addCost(2 * InstrCost);
+      // The jump table only requires a jump instruction.
       int64_t JTCost =
-          static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
+          static_cast<int64_t>(JumpTableSize) * InstrCost + InstrCost;
       addCost(JTCost);
       return;
     }
     if (NumCaseCluster <= 3) {
       // Suppose a comparison includes one compare and one conditional branch.
-      addCost(NumCaseCluster * 2 * InstrCost);
+      // We can reduce a set of instructions if the default branch is
+      // undefined.
+      addCost((NumCaseCluster - DefaultDestUndefined) * 2 * InstrCost);
       return;
     }
     int64_t ExpectedNumberOfCompare =
-        getExpectedNumberOfCompare(NumCaseCluster);
+        getExpectedNumberOfCompare(NumCaseCluster, DefaultDestUndefined);
     int64_t SwitchCost = ExpectedNumberOfCompare * 2 * InstrCost;
     addCost(SwitchCost);
@@ -1152,7 +1163,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
   // FIXME: These constants are taken from the heuristic-based cost visitor.
   // These should be removed entirely in a later revision to avoid reliance on
   // heuristics in the ML inliner.
-  static constexpr int JTCostMultiplier = 4;
+  static constexpr int JTCostMultiplier = 1;
   static constexpr int CaseClusterCostMultiplier = 2;
   static constexpr int SwitchDefaultDestCostMultiplier = 2;
   static constexpr int SwitchCostMultiplier = 2;
@@ -1235,11 +1246,10 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      increment(InlineCostFeatureIndex::switch_default_dest_penalty,
-                SwitchDefaultDestCostMultiplier * InstrCost);
-
     if (JumpTableSize) {
+      if (!DefaultDestUndefined)
+        increment(InlineCostFeatureIndex::switch_default_dest_penalty,
+                  SwitchDefaultDestCostMultiplier * InstrCost);
       int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
                        JTCostMultiplier * InstrCost;
       increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
@@ -1248,12 +1258,13 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
     if (NumCaseCluster <= 3) {
       increment(InlineCostFeatureIndex::case_cluster_penalty,
-                NumCaseCluster * CaseClusterCostMultiplier * InstrCost);
+                (NumCaseCluster - DefaultDestUndefined) *
+                    CaseClusterCostMultiplier * InstrCost);
       return;
     }
     int64_t ExpectedNumberOfCompare =
-        getExpectedNumberOfCompare(NumCaseCluster);
+        getExpectedNumberOfCompare(NumCaseCluster, DefaultDestUndefined);
     int64_t SwitchCost =
         ExpectedNumberOfCompare * SwitchCostMultiplier * InstrCost;
diff --git a/llvm/test/Transforms/Inline/inline-switch-default-2.ll b/llvm/test/Transforms/Inline/inline-switch-default-2.ll
index 8d3e24c798df82..1a648300ae3c1e 100644
--- a/llvm/test/Transforms/Inline/inline-switch-default-2.ll
+++ b/llvm/test/Transforms/Inline/inline-switch-default-2.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=21 | FileCheck %s
+; RUN: opt %s -S -passes=inline -inline-threshold=11 | FileCheck %s
 ; Check for scenarios without TTI.
diff --git a/llvm/test/Transforms/Inline/inline-switch-default.ll b/llvm/test/Transforms/Inline/inline-switch-default.ll
index 44f1304e82dff0..6a50820aad3a7d 100644
--- a/llvm/test/Transforms/Inline/inline-switch-default.ll
+++ b/llvm/test/Transforms/Inline/inline-switch-default.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=26 -min-jump-table-entries=4 | FileCheck %s -check-prefix=LOOKUPTABLE
-; RUN: opt %s -S -passes=inline -inline-threshold=21 -min-jump-table-entries=5 | FileCheck %s -check-prefix=SWITCH
+; RUN: opt %s -S -passes=inline -inline-threshold=16 -min-jump-table-entries=4 | FileCheck %s -check-prefix=LOOKUPTABLE
+; RUN: opt %s -S -passes=inline -inline-threshold=11 -min-jump-table-entries=5 | FileCheck %s -check-prefix=SWITCH
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
On most targets other than X86 load+jump is 2 instructions.
Thanks. I found that targets can differ quite a bit across switch scenarios, but I think the patch is at least a little closer now.
Ping for review :) I am happy with the performance data on RISC-V. This patch saves ~8% instructions in some benchmarks.
llvm/lib/Analysis/InlineCost.cpp
Outdated
if (DefaultDestUndefined) {
  return static_cast<int64_t>(NumCaseCluster) - 1;
}
Suggested change:
-if (DefaultDestUndefined) {
-  return static_cast<int64_t>(NumCaseCluster) - 1;
-}
+if (DefaultDestUndefined)
+  return static_cast<int64_t>(NumCaseCluster) - 1;
Please drop the braces.
Is this heuristic correct?
Here is an example: https://llvm.godbolt.org/z/x6ETdfY79. If there are common target branches, the number of compare instructions will decrease. I haven't started learning about instruction selection, so I will leave a note here.
✅ With the latest revision this PR passed the Python code formatter.
✅ With the latest revision this PR passed the C/C++ code formatter.
a587554 to c786e6b
Ping. Since the judgment in |
Please don't put godbolt links in source code. If you absolutely need to refer to something, please file a bug and refer to that.
I'd like to see a test that actually verifies the computed cost here.
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=21 | FileCheck %s
+; RUN: opt %s -S -passes=inline -inline-threshold=11 | FileCheck %s
Please fix the FileCheck lines in this file; there's a bunch of checks for stuff like "LOOKUPTABLE" which don't correspond to a RUN line. (Maybe push this separately.)
Updated: 971ec1f
c786e6b to 98b3234
The new issue is #90929.
Done. I'm sorry for the late response.
Seems reasonable to me
Fixes #81723.
I use the following patch to find functions that are not inlined after #77856.
patch.diff
There are over 20,000 call sites that don't satisfy the inline condition. I tried to select 10 of them:
Details
DoLowering(llvm::Function&, llvm::GCStrategy&): llvm-project/llvm/lib/CodeGen/GCRootLowering.cpp, line 201 in 5aec939
getFromRangeMetadata(llvm::Instruction*): llvm-project/llvm/lib/Analysis/LazyValueInfo.cpp, line 589 in 5aec939
clang::comments::DeclInfo::fill(): llvm-project/clang/lib/AST/Comment.cpp, line 203 in 5aec939
clang::APValue::DestroyDataAndMakeUninit(): llvm-project/clang/lib/AST/APValue.cpp, line 403 in 5aec939
clang::targets::MipsTargetInfo::getISARev() const: llvm-project/clang/lib/Basic/Targets/Mips.cpp, line 61 in 5aec939
llvm::isLegalUTF8(unsigned char const*, int): llvm-project/llvm/lib/Support/ConvertUTF.cpp, line 397 in 5aec939
llvm::yaml::Input::createHNodes(llvm::yaml::Node*): llvm-project/llvm/lib/Support/YAMLTraits.cpp, line 401 in 5aec939
clang::Parser::ParseOpenACCDirective(): llvm-project/clang/lib/Parse/ParseOpenACC.cpp, line 1119 in 5aec939
clang::OMPClauseReader::readClause(): llvm-project/clang/lib/Serialization/ASTReader.cpp, line 10263 in 5aec939
clang::CodeGen::CodeGenFunction::EmitLandingPad(): llvm-project/clang/lib/CodeGen/CGException.cpp, line 825 in 5aec939
There are complex switch statements that cannot be transformed to simpler structures.
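For readers skimming the diagnostic patch, its selection condition can be read as a small predicate: a call site is reported when it is over the threshold, but would fit if the cost charged for reachable default branches were left out. The sketch below is only an illustration; `rejectedOnlyByDefaultBranch` is a hypothetical helper name, not an LLVM function, and the numbers are made up.

```cpp
// Hypothetical helper (not part of LLVM) mirroring the condition used in the
// diagnostic patch: the call site exceeds the threshold, yet would fit if
// the cost attributed to reachable default branches were not charged.
constexpr bool rejectedOnlyByDefaultBranch(int Cost, int Threshold,
                                           int DefaultBranchCost) {
  return Cost > Threshold && Cost - DefaultBranchCost <= Threshold;
}

// Made-up numbers for illustration, not measurements from this PR.
static_assert(rejectedOnlyByDefaultBranch(230, 225, 10),
              "230 > 225 but 220 <= 225: reported");
static_assert(!rejectedOnlyByDefaultBranch(240, 225, 10),
              "230 > 225 even without the default cost: not reported");
static_assert(!rejectedOnlyByDefaultBranch(200, 225, 10),
              "not over the threshold in the first place: not reported");
```

The checks are `static_assert`s, so the snippet verifies itself at compile time and needs no `main`.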
The earliest commit of the related code is: 919f9e8. I tried to understand the following code with #77856 (comment).
llvm-project/llvm/lib/Analysis/InlineCost.cpp
Lines 709 to 720 in 5932fcc
I think only scenarios where there is a default branch were considered.
Taking https://llvm.godbolt.org/z/5cno1TnGx as an example, we need additional compare and jump instructions when there is a default branch, otherwise we just need a jump instruction.
But I don't know why it's 4 * InstrCost and not 3 * InstrCost.
Taking https://llvm.godbolt.org/z/MEsf9sno7 as an example, we can reduce a set of compare and jump instructions when the number of branches is small.
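As a rough sketch of the two cases above, the revised costs in the diff on this page can be written as standalone functions. The names `jumpTableCost` and `smallSwitchCost` are hypothetical, and `InstrCost = 5` is an assumption about the value of `InlineConstants::InstrCost`; the logic follows this revision of the patch and may differ from what was finally merged.

```cpp
#include <cstdint>

// Assumed value of InlineConstants::InstrCost (an assumption, see above).
constexpr int64_t InstrCost = 5;

// Hypothetical helper: cost of a switch lowered to a jump table. One
// indirect jump plus the table entries, and an extra compare + conditional
// branch only when the default destination is reachable.
constexpr int64_t jumpTableCost(unsigned JumpTableSize,
                                bool DefaultDestUndefined) {
  return static_cast<int64_t>(JumpTableSize) * InstrCost + InstrCost +
         (DefaultDestUndefined ? 0 : 2 * InstrCost);
}

// Hypothetical helper: cost of a small switch (at most 3 case clusters).
// One compare + conditional branch per cluster, minus one pair when the
// default destination is unreachable.
constexpr int64_t smallSwitchCost(unsigned NumCaseCluster,
                                  bool DefaultDestUndefined) {
  return (static_cast<int64_t>(NumCaseCluster) -
          (DefaultDestUndefined ? 1 : 0)) *
         2 * InstrCost;
}

// A 7-entry table, as in the first godbolt example (cmp rdi, 6).
static_assert(jumpTableCost(7, false) == 50, "7*5 + 5 + 10");
static_assert(jumpTableCost(7, true) == 40, "7*5 + 5");
// Three clusters, as in the second godbolt example.
static_assert(smallSwitchCost(3, false) == 30, "3 compare+branch pairs");
static_assert(smallSwitchCost(3, true) == 20, "2 compare+branch pairs");
```

In both cases an unreachable default saves exactly one compare plus one conditional branch, i.e. `2 * InstrCost`, which matches the assembly differences between `foo` and `bar` described above.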
Further, for scenarios with more branches, the number of generated compare instructions should be less than the number of branches if the default branch is undefined behavior. There will be even fewer compare instructions if some cases share a common target branch.
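To make the difference concrete, here is a small sketch of the two estimates as they appear in the diff on this page. `expectedNumberOfCompare` is a stand-in written for this example, not the LLVM function itself, and the `n - 1` estimate is the heuristic proposed in this revision (the review thread asks whether it is correct), so treat the numbers as illustration only.

```cpp
#include <cstdint>

// Stand-in for getExpectedNumberOfCompare as shown in this revision:
//  - reachable default:   n * 3 / 2 - 1 compares (balanced-tree estimate)
//  - unreachable default: n - 1 compares (fewer compares than clusters)
constexpr int64_t expectedNumberOfCompare(int NumCaseCluster,
                                          bool DefaultDestUndefined) {
  return DefaultDestUndefined
             ? static_cast<int64_t>(NumCaseCluster) - 1
             : 3 * static_cast<int64_t>(NumCaseCluster) / 2 - 1;
}

static_assert(expectedNumberOfCompare(4, false) == 5, "3*4/2 - 1");
static_assert(expectedNumberOfCompare(4, true) == 3, "4 - 1");
static_assert(expectedNumberOfCompare(8, false) == 11, "3*8/2 - 1");
static_assert(expectedNumberOfCompare(8, true) == 7, "8 - 1");
```

Each expected compare is then charged as one compare plus one conditional branch (`2 * InstrCost`) in the switch cost.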
Result of reverting #77856: https://llvm-compile-time-tracker.com/compare.php?from=f3c5278efa3b783ada9e7a34b751cf4c5b864535&to=58622ef6755a02f97e5127bea29ed5b8812fe25e&stat=instructions:u.
New change: https://llvm-compile-time-tracker.com/compare.php?from=58622ef6755a02f97e5127bea29ed5b8812fe25e&to=dc2c2faa82d3d7b998680267a79895eb4969e6fd&stat=instructions%3Au.