Skip to content

Commit 732f536

Browse files
vzakharijoker-eph
andauthored
[RFC][mlir] Add profitability callback to the Inliner. (#84258)
Discussion at https://discourse.llvm.org/t/inliner-cost-model/2992 This change adds a callback that reports whether inlining of the particular call site (communicated via ResolvedCall argument) is profitable or not. The default MLIR inliner pass behavior is unchanged, i.e. the callback always returns true. This callback may be used to customize the inliner behavior based on the target specifics (like target instructions costs), profitability of the inlining for further optimizations (e.g. if inlining may enable loop optimizations or scalar optimizations due to object shape propagation), optimization levels (e.g. -Os inlining may be quite different from -Ofast inlining), etc. One of the questions is whether the ResolvedCall entity represents enough of the context for the custom inlining models to come up with the profitability decision. I think we can start with this and extend it as necessary. --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 096ee4e commit 732f536

File tree

6 files changed

+92
-19
lines changed

6 files changed

+92
-19
lines changed

mlir/include/mlir/Transforms/Inliner.h

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,6 @@ class InlinerConfig {
6969
/// of inlining decisions from the leafs to the roots of the callgraph.
7070
class Inliner {
7171
public:
72-
using RunPipelineHelperTy = std::function<LogicalResult(
73-
Pass &pass, OpPassManager &pipeline, Operation *op)>;
74-
75-
Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am,
76-
RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config)
77-
: op(op), cg(cg), pass(pass), am(am),
78-
runPipelineHelper(std::move(runPipelineHelper)), config(config) {}
79-
Inliner(Inliner &) = delete;
80-
void operator=(const Inliner &) = delete;
81-
82-
/// Perform inlining on a OpTrait::SymbolTable operation.
83-
LogicalResult doInlining();
84-
8572
/// This struct represents a resolved call to a given callgraph node. Given
8673
/// that the call does not actually contain a direct reference to the
8774
/// Region(CallGraphNode) that it is dispatching to, we need to resolve them
@@ -94,7 +81,29 @@ class Inliner {
9481
CallGraphNode *sourceNode, *targetNode;
9582
};
9683

97-
protected:
84+
using RunPipelineHelperTy = std::function<LogicalResult(
85+
Pass &pass, OpPassManager &pipeline, Operation *op)>;
86+
87+
/// Type of the callback answering if it is profitable
88+
/// to inline a callable operation at a call site.
89+
/// It might be the case that the ResolvedCall does not provide
90+
/// enough context to make the profitability decision, so
91+
/// this hook's interface might need to be extended in future.
92+
using ProfitabilityCallbackTy = std::function<bool(const ResolvedCall &)>;
93+
94+
Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am,
95+
RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config,
96+
ProfitabilityCallbackTy isProfitableToInline)
97+
: op(op), cg(cg), pass(pass), am(am),
98+
runPipelineHelper(std::move(runPipelineHelper)), config(config),
99+
isProfitableToInline(std::move(isProfitableToInline)) {}
100+
Inliner(Inliner &) = delete;
101+
void operator=(const Inliner &) = delete;
102+
103+
/// Perform inlining on a OpTrait::SymbolTable operation.
104+
LogicalResult doInlining();
105+
106+
private:
98107
/// An OpTrait::SymbolTable operation to run the inlining on.
99108
Operation *op;
100109
/// A CallGraph analysis for the given operation.
@@ -108,12 +117,12 @@ class Inliner {
108117
const RunPipelineHelperTy runPipelineHelper;
109118
/// The inliner configuration parameters.
110119
const InlinerConfig &config;
120+
/// Returns true, if it is profitable to inline the callable operation
121+
/// at the call site.
122+
ProfitabilityCallbackTy isProfitableToInline;
111123

112-
private:
113124
/// Forward declaration of the class providing the actual implementation.
114125
class Impl;
115-
116-
public:
117126
};
118127
} // namespace mlir
119128

mlir/include/mlir/Transforms/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,13 @@ def Inliner : Pass<"inline"> {
278278
Option<"maxInliningIterations", "max-iterations", "unsigned",
279279
/*default=*/"4",
280280
"Maximum number of iterations when inlining within an SCC">,
281+
Option<"inliningThreshold", "inlining-threshold", "unsigned",
282+
/*default=*/"-1U",
283+
"If the ratio between the number of the operations "
284+
"in the callee and the number of the operations "
285+
"in the caller exceeds this value (in percentage), "
286+
"then the callee is not inlined even if it is legal "
287+
"to inline it">,
281288
];
282289
}
283290

mlir/lib/Transforms/InlinerPass.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ namespace mlir {
2424
#include "mlir/Transforms/Passes.h.inc"
2525
} // namespace mlir
2626

27+
#define DEBUG_TYPE "inliner-pass"
28+
2729
using namespace mlir;
2830

2931
/// This function implements the inliner optimization pipeline.
@@ -88,6 +90,35 @@ InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
8890
config.setOpPipelines(std::move(opPipelines));
8991
}
9092

93+
// Return true if the inlining ratio does not exceed the threshold.
94+
static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall,
95+
unsigned inliningThreshold) {
96+
Region *callerRegion = resolvedCall.sourceNode->getCallableRegion();
97+
Region *calleeRegion = resolvedCall.targetNode->getCallableRegion();
98+
99+
// We should not get external nodes here, but just return true
100+
// for now to preserve the original behavior of the inliner pass.
101+
if (!calleeRegion || !calleeRegion)
102+
return true;
103+
104+
auto countOps = [](Region *region) {
105+
unsigned count = 0;
106+
region->walk([&](Operation *) { ++count; });
107+
return count;
108+
};
109+
110+
unsigned callerOps = countOps(callerRegion);
111+
112+
// Always inline empty callees (if it is possible at all).
113+
if (callerOps == 0)
114+
return true;
115+
116+
unsigned ratio = countOps(calleeRegion) * 100 / callerOps;
117+
LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: "
118+
<< inliningThreshold << "%): " << ratio << "%\n");
119+
return ratio <= inliningThreshold;
120+
}
121+
91122
void InlinerPass::runOnOperation() {
92123
CallGraph &cg = getAnalysis<CallGraph>();
93124

@@ -100,9 +131,14 @@ void InlinerPass::runOnOperation() {
100131
return signalPassFailure();
101132
}
102133

134+
// By default, assume that any inlining is profitable.
135+
auto profitabilityCb = [=](const Inliner::ResolvedCall &call) {
136+
return isProfitableToInline(call, inliningThreshold);
137+
};
138+
103139
// Get an instance of the inliner.
104140
Inliner inliner(op, cg, *this, getAnalysisManager(), runPipelineHelper,
105-
config);
141+
config, profitabilityCb);
106142

107143
// Run the inlining.
108144
if (failed(inliner.doInlining()))

mlir/lib/Transforms/Utils/Inliner.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,9 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
741741
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
742742
return false;
743743

744+
if (!inliner.isProfitableToInline(resolvedCall))
745+
return false;
746+
744747
// Otherwise, inline.
745748
return true;
746749
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
// RUN: mlir-opt %s -pass-pipeline="builtin.module(inline)" -dump-pass-pipeline 2>&1 | FileCheck %s
2-
// CHECK: builtin.module(inline{default-pipeline=canonicalize max-iterations=4 })
2+
// CHECK: builtin.module(inline{default-pipeline=canonicalize inlining-threshold=4294967295 max-iterations=4 })
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline='' inlining-threshold=100' -debug-only=inliner-pass 2>&1 | FileCheck %s
2+
3+
// Check that inlining does not happen when the threshold is exceeded.
4+
func.func @callee1(%arg : i32) -> i32 {
5+
%v1 = arith.addi %arg, %arg : i32
6+
%v2 = arith.addi %v1, %arg : i32
7+
%v3 = arith.addi %v2, %arg : i32
8+
return %v3 : i32
9+
}
10+
11+
// CHECK-LABEL: func @caller1
12+
func.func @caller1(%arg0 : i32) -> i32 {
13+
// CHECK-NEXT: call @callee1
14+
// CHECK-NEXT: return
15+
16+
%0 = call @callee1(%arg0) : (i32) -> i32
17+
return %0 : i32
18+
}

0 commit comments

Comments
 (0)