Skip to content

Commit d7eba20

Browse files
committed
[mlir][Inliner] Refactor the inliner to use nested pass pipelines instead of just canonicalization
Now that passes have support for running nested pipelines, the inliner can now allow for users to provide proper nested pipelines to use for optimization during inlining. This revision also changes the behavior of optimization during inlining to optimize before attempting to inline, which should lead to a more accurate cost model and prevents the need for users to schedule additional duplicate cleanup passes before/after the inliner that would already be run during inlining. Differential Revision: https://reviews.llvm.org/D91211
1 parent f0cd6aa commit d7eba20

File tree

16 files changed

+436
-203
lines changed

16 files changed

+436
-203
lines changed

llvm/include/llvm/ADT/Sequence.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class value_sequence_iterator
4242
value_sequence_iterator(const value_sequence_iterator &) = default;
4343
value_sequence_iterator(value_sequence_iterator &&Arg)
4444
: Value(std::move(Arg.Value)) {}
45+
value_sequence_iterator &operator=(const value_sequence_iterator &Arg) {
46+
Value = Arg.Value;
47+
return *this;
48+
}
4549

4650
template <typename U, typename Enabler = decltype(ValueT(std::declval<U>()))>
4751
value_sequence_iterator(U &&Value) : Value(std::forward<U>(Value)) {}

mlir/include/mlir/Pass/AnalysisManager.h

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ struct AnalysisConcept {
9898
/// A derived analysis model used to hold a specific analysis object.
9999
template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
100100
template <typename... Args>
101-
explicit AnalysisModel(Args &&... args)
101+
explicit AnalysisModel(Args &&...args)
102102
: analysis(std::forward<Args>(args)...) {}
103103

104104
/// A hook used to query analyses for invalidation.
@@ -198,19 +198,45 @@ class AnalysisMap {
198198
/// An analysis map that contains a map for the current operation, and a set of
199199
/// maps for any child operations.
200200
struct NestedAnalysisMap {
201-
NestedAnalysisMap(Operation *op) : analyses(op) {}
201+
NestedAnalysisMap(Operation *op, PassInstrumentor *instrumentor)
202+
: analyses(op), parentOrInstrumentor(instrumentor) {}
203+
NestedAnalysisMap(Operation *op, NestedAnalysisMap *parent)
204+
: analyses(op), parentOrInstrumentor(parent) {}
202205

203206
/// Get the operation for this analysis map.
204207
Operation *getOperation() const { return analyses.getOperation(); }
205208

206209
/// Invalidate any non preserved analyses.
207210
void invalidate(const PreservedAnalyses &pa);
208211

212+
/// Returns the parent analysis map for this analysis map, or null if this is
213+
/// the top-level map.
214+
const NestedAnalysisMap *getParent() const {
215+
return parentOrInstrumentor.dyn_cast<NestedAnalysisMap *>();
216+
}
217+
218+
/// Returns a pass instrumentation object for the current operation. This
219+
/// value may be null.
220+
PassInstrumentor *getPassInstrumentor() const {
221+
if (auto *parent = getParent())
222+
return parent->getPassInstrumentor();
223+
return parentOrInstrumentor.get<PassInstrumentor *>();
224+
}
225+
209226
/// The cached analyses for nested operations.
210227
DenseMap<Operation *, std::unique_ptr<NestedAnalysisMap>> childAnalyses;
211228

212-
/// The analyses for the owning module.
229+
/// The analyses for the owning operation.
213230
detail::AnalysisMap analyses;
231+
232+
/// This value has three possible states:
233+
/// NestedAnalysisMap*: A pointer to the parent analysis map.
234+
/// PassInstrumentor*: This analysis map is the top-level map, and this
235+
/// pointer is the optional pass instrumentor for the
236+
/// current compilation.
237+
/// nullptr: This analysis map is the top-level map, and there is nop pass
238+
/// instrumentor.
239+
PointerUnion<NestedAnalysisMap *, PassInstrumentor *> parentOrInstrumentor;
214240
};
215241
} // namespace detail
216242

@@ -236,11 +262,11 @@ class AnalysisManager {
236262
template <typename AnalysisT>
237263
Optional<std::reference_wrapper<AnalysisT>>
238264
getCachedParentAnalysis(Operation *parentOp) const {
239-
ParentPointerT curParent = parent;
240-
while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>()) {
241-
if (parentAM->impl->getOperation() == parentOp)
242-
return parentAM->getCachedAnalysis<AnalysisT>();
243-
curParent = parentAM->parent;
265+
const detail::NestedAnalysisMap *curParent = impl;
266+
while (auto *parentAM = curParent->getParent()) {
267+
if (parentAM->getOperation() == parentOp)
268+
return parentAM->analyses.getCachedAnalysis<AnalysisT>();
269+
curParent = parentAM;
244270
}
245271
return None;
246272
}
@@ -286,7 +312,8 @@ class AnalysisManager {
286312
return it->second->analyses.getCachedAnalysis<AnalysisT>();
287313
}
288314

289-
/// Get an analysis manager for the given child operation.
315+
/// Get an analysis manager for the given operation, which must be a proper
316+
/// descendant of the current operation represented by this analysis manager.
290317
AnalysisManager nest(Operation *op);
291318

292319
/// Invalidate any non preserved analyses,
@@ -300,19 +327,15 @@ class AnalysisManager {
300327

301328
/// Returns a pass instrumentation object for the current operation. This
302329
/// value may be null.
303-
PassInstrumentor *getPassInstrumentor() const;
330+
PassInstrumentor *getPassInstrumentor() const {
331+
return impl->getPassInstrumentor();
332+
}
304333

305334
private:
306-
AnalysisManager(const AnalysisManager *parent,
307-
detail::NestedAnalysisMap *impl)
308-
: parent(parent), impl(impl) {}
309-
AnalysisManager(const ModuleAnalysisManager *parent,
310-
detail::NestedAnalysisMap *impl)
311-
: parent(parent), impl(impl) {}
335+
AnalysisManager(detail::NestedAnalysisMap *impl) : impl(impl) {}
312336

313-
/// A reference to the parent analysis manager, or the top-level module
314-
/// analysis manager.
315-
ParentPointerT parent;
337+
/// Get an analysis manager for the given immediately nested child operation.
338+
AnalysisManager nestImmediate(Operation *op);
316339

317340
/// A reference to the impl analysis map within the parent analysis manager.
318341
detail::NestedAnalysisMap *impl;
@@ -328,23 +351,16 @@ class AnalysisManager {
328351
class ModuleAnalysisManager {
329352
public:
330353
ModuleAnalysisManager(Operation *op, PassInstrumentor *passInstrumentor)
331-
: analyses(op), passInstrumentor(passInstrumentor) {}
354+
: analyses(op, passInstrumentor) {}
332355
ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
333356
ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
334357

335-
/// Returns a pass instrumentation object for the current module. This value
336-
/// may be null.
337-
PassInstrumentor *getPassInstrumentor() const { return passInstrumentor; }
338-
339358
/// Returns an analysis manager for the current top-level module.
340-
operator AnalysisManager() { return AnalysisManager(this, &analyses); }
359+
operator AnalysisManager() { return AnalysisManager(&analyses); }
341360

342361
private:
343362
/// The analyses for the owning module.
344363
detail::NestedAnalysisMap analyses;
345-
346-
/// An optional instrumentation object.
347-
PassInstrumentor *passInstrumentor;
348364
};
349365

350366
} // end namespace mlir

mlir/include/mlir/Pass/Pass.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class Pass {
9595
typename OptionParser = detail::PassOptions::OptionParser<DataType>>
9696
struct Option : public detail::PassOptions::Option<DataType, OptionParser> {
9797
template <typename... Args>
98-
Option(Pass &parent, StringRef arg, Args &&... args)
98+
Option(Pass &parent, StringRef arg, Args &&...args)
9999
: detail::PassOptions::Option<DataType, OptionParser>(
100100
parent.passOptions, arg, std::forward<Args>(args)...) {}
101101
using detail::PassOptions::Option<DataType, OptionParser>::operator=;
@@ -107,14 +107,17 @@ class Pass {
107107
struct ListOption
108108
: public detail::PassOptions::ListOption<DataType, OptionParser> {
109109
template <typename... Args>
110-
ListOption(Pass &parent, StringRef arg, Args &&... args)
110+
ListOption(Pass &parent, StringRef arg, Args &&...args)
111111
: detail::PassOptions::ListOption<DataType, OptionParser>(
112112
parent.passOptions, arg, std::forward<Args>(args)...) {}
113113
using detail::PassOptions::ListOption<DataType, OptionParser>::operator=;
114114
};
115115

116116
/// Attempt to initialize the options of this pass from the given string.
117-
LogicalResult initializeOptions(StringRef options);
117+
/// Derived classes may override this method to hook into the point at which
118+
/// options are initialized, but should generally always invoke this base
119+
/// class variant.
120+
virtual LogicalResult initializeOptions(StringRef options);
118121

119122
/// Prints out the pass in the textual representation of pipelines. If this is
120123
/// an adaptor pass, print with the op_name(sub_pass,...) format.
@@ -265,7 +268,6 @@ class Pass {
265268
void copyOptionValuesFrom(const Pass *other);
266269

267270
private:
268-
269271
/// Out of line virtual method to ensure vtables and metadata are emitted to a
270272
/// single .o file.
271273
virtual void anchor();

mlir/include/mlir/Pass/PassManager.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ struct PassExecutionState;
4848
class OpPassManager {
4949
public:
5050
enum class Nesting { Implicit, Explicit };
51-
OpPassManager(Identifier name, Nesting nesting);
52-
OpPassManager(StringRef name, Nesting nesting);
51+
OpPassManager(Identifier name, Nesting nesting = Nesting::Explicit);
52+
OpPassManager(StringRef name, Nesting nesting = Nesting::Explicit);
5353
OpPassManager(OpPassManager &&rhs);
5454
OpPassManager(const OpPassManager &rhs);
5555
~OpPassManager();

mlir/include/mlir/Transforms/Passes.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ std::unique_ptr<Pass> createPrintOpStatsPass();
107107
/// Creates a pass which inlines calls and callable operations as defined by
108108
/// the CallGraph.
109109
std::unique_ptr<Pass> createInlinerPass();
110+
/// Creates an instance of the inliner pass, and use the provided pass managers
111+
/// when optimizing callable operations with names matching the key type.
112+
/// Callable operations with a name not within the provided map will use the
113+
/// default inliner pipeline during optimization.
114+
std::unique_ptr<Pass>
115+
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines);
116+
/// Creates an instance of the inliner pass, and use the provided pass managers
117+
/// when optimizing callable operations with names matching the key type.
118+
/// Callable operations with a name not within the provided map will use the
119+
/// provided default pipeline builder.
120+
std::unique_ptr<Pass>
121+
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
122+
std::function<void(OpPassManager &)> defaultPipelineBuilder);
110123

111124
/// Creates a pass which performs sparse conditional constant propagation over
112125
/// nested operations.

mlir/include/mlir/Transforms/Passes.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,12 @@ def Inliner : Pass<"inline"> {
285285
let summary = "Inline function calls";
286286
let constructor = "mlir::createInlinerPass()";
287287
let options = [
288-
Option<"disableCanonicalization", "disable-simplify", "bool",
289-
/*default=*/"false",
290-
"Disable running simplifications during inlining">,
288+
Option<"defaultPipelineStr", "default-pipeline", "std::string",
289+
/*default=*/"", "The default optimizer pipeline used for callables">,
290+
ListOption<"opPipelineStrs", "op-pipelines", "std::string",
291+
"Callable operation specific optimizer pipelines (in the form "
292+
"of `dialect.op(pipeline)`)",
293+
"llvm::cl::MiscFlags::CommaSeparated">,
291294
Option<"maxInliningIterations", "max-iterations", "unsigned",
292295
/*default=*/"4",
293296
"Maximum number of iterations when inlining within an SCC">,

mlir/lib/Pass/Pass.cpp

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -340,22 +340,25 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
340340

341341
// Initialize the pass state with a callback for the pass to dynamically
342342
// execute a pipeline on the currently visited operation.
343-
auto dynamic_pipeline_callback =
344-
[op, &am, verifyPasses](OpPassManager &pipeline,
345-
Operation *root) -> LogicalResult {
343+
PassInstrumentor *pi = am.getPassInstrumentor();
344+
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
345+
pass};
346+
auto dynamic_pipeline_callback = [&](OpPassManager &pipeline,
347+
Operation *root) -> LogicalResult {
346348
if (!op->isAncestor(root))
347349
return root->emitOpError()
348350
<< "Trying to schedule a dynamic pipeline on an "
349351
"operation that isn't "
350352
"nested under the current operation the pass is processing";
353+
assert(pipeline.getOpName() == root->getName().getStringRef());
351354

352-
AnalysisManager nestedAm = am.nest(root);
355+
AnalysisManager nestedAm = root == op ? am : am.nest(root);
353356
return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
354-
verifyPasses);
357+
verifyPasses, pi, &parentInfo);
355358
};
356359
pass->passState.emplace(op, am, dynamic_pipeline_callback);
360+
357361
// Instrument before the pass has run.
358-
PassInstrumentor *pi = am.getPassInstrumentor();
359362
if (pi)
360363
pi->runBeforePass(pass, op);
361364

@@ -388,7 +391,10 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
388391
/// Run the given operation and analysis manager on a provided op pass manager.
389392
LogicalResult OpToOpPassAdaptor::runPipeline(
390393
iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
391-
AnalysisManager am, bool verifyPasses) {
394+
AnalysisManager am, bool verifyPasses, PassInstrumentor *instrumentor,
395+
const PassInstrumentation::PipelineParentInfo *parentInfo) {
396+
assert((!instrumentor || parentInfo) &&
397+
"expected parent info if instrumentor is provided");
392398
auto scope_exit = llvm::make_scope_exit([&] {
393399
// Clear out any computed operation analyses. These analyses won't be used
394400
// any more in this pipeline, and this helps reduce the current working set
@@ -398,10 +404,13 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
398404
});
399405

400406
// Run the pipeline over the provided operation.
407+
if (instrumentor)
408+
instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo);
401409
for (Pass &pass : passes)
402410
if (failed(run(&pass, op, am, verifyPasses)))
403411
return failure();
404-
412+
if (instrumentor)
413+
instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo);
405414
return success();
406415
}
407416

@@ -491,17 +500,10 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
491500
*op.getContext());
492501
if (!mgr)
493502
continue;
494-
Identifier opName = mgr->getOpName(*getOperation()->getContext());
495503

496504
// Run the held pipeline over the current operation.
497-
if (instrumentor)
498-
instrumentor->runBeforePipeline(opName, parentInfo);
499-
LogicalResult result =
500-
runPipeline(mgr->getPasses(), &op, am.nest(&op), verifyPasses);
501-
if (instrumentor)
502-
instrumentor->runAfterPipeline(opName, parentInfo);
503-
504-
if (failed(result))
505+
if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op),
506+
verifyPasses, instrumentor, &parentInfo)))
505507
return signalPassFailure();
506508
}
507509
}
@@ -576,13 +578,9 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
576578
pms, it.first->getName().getIdentifier(), getContext());
577579
assert(pm && "expected valid pass manager for operation");
578580

579-
Identifier opName = pm->getOpName(*getOperation()->getContext());
580-
if (instrumentor)
581-
instrumentor->runBeforePipeline(opName, parentInfo);
582-
auto pipelineResult =
583-
runPipeline(pm->getPasses(), it.first, it.second, verifyPasses);
584-
if (instrumentor)
585-
instrumentor->runAfterPipeline(opName, parentInfo);
581+
LogicalResult pipelineResult =
582+
runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
583+
instrumentor, &parentInfo);
586584

587585
// Drop this thread from being tracked by the diagnostic handler.
588586
// After this task has finished, the thread may be used outside of
@@ -848,22 +846,41 @@ void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
848846
// AnalysisManager
849847
//===----------------------------------------------------------------------===//
850848

851-
/// Returns a pass instrumentation object for the current operation.
852-
PassInstrumentor *AnalysisManager::getPassInstrumentor() const {
853-
ParentPointerT curParent = parent;
854-
while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>())
855-
curParent = parentAM->parent;
856-
return curParent.get<const ModuleAnalysisManager *>()->getPassInstrumentor();
849+
/// Get an analysis manager for the given operation, which must be a proper
850+
/// descendant of the current operation represented by this analysis manager.
851+
AnalysisManager AnalysisManager::nest(Operation *op) {
852+
Operation *currentOp = impl->getOperation();
853+
assert(currentOp->isProperAncestor(op) &&
854+
"expected valid descendant operation");
855+
856+
// Check for the base case where the provided operation is immediately nested.
857+
if (currentOp == op->getParentOp())
858+
return nestImmediate(op);
859+
860+
// Otherwise, we need to collect all ancestors up to the current operation.
861+
SmallVector<Operation *, 4> opAncestors;
862+
do {
863+
opAncestors.push_back(op);
864+
op = op->getParentOp();
865+
} while (op != currentOp);
866+
867+
AnalysisManager result = *this;
868+
for (Operation *op : llvm::reverse(opAncestors))
869+
result = result.nestImmediate(op);
870+
return result;
857871
}
858872

859-
/// Get an analysis manager for the given child operation.
860-
AnalysisManager AnalysisManager::nest(Operation *op) {
873+
/// Get an analysis manager for the given immediately nested child operation.
874+
AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
875+
assert(impl->getOperation() == op->getParentOp() &&
876+
"expected immediate child operation");
877+
861878
auto it = impl->childAnalyses.find(op);
862879
if (it == impl->childAnalyses.end())
863880
it = impl->childAnalyses
864-
.try_emplace(op, std::make_unique<NestedAnalysisMap>(op))
881+
.try_emplace(op, std::make_unique<NestedAnalysisMap>(op, impl))
865882
.first;
866-
return {this, it->second.get()};
883+
return {it->second.get()};
867884
}
868885

869886
/// Invalidate any non preserved analyses.

mlir/lib/Pass/PassDetail.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ class OpToOpPassAdaptor
6060

6161
/// Run the given operation and analysis manager on a provided op pass
6262
/// manager.
63-
static LogicalResult
64-
runPipeline(iterator_range<OpPassManager::pass_iterator> passes,
65-
Operation *op, AnalysisManager am, bool verifyPasses);
63+
static LogicalResult runPipeline(
64+
iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
65+
AnalysisManager am, bool verifyPasses,
66+
PassInstrumentor *instrumentor = nullptr,
67+
const PassInstrumentation::PipelineParentInfo *parentInfo = nullptr);
6668

6769
/// A set of adaptors to run.
6870
SmallVector<OpPassManager, 1> mgrs;

0 commit comments

Comments
 (0)