Skip to content

Commit 9c9a431

Browse files
committed
[mlir][Pass] Add support for an InterfacePass and pass filtering based on OperationName
This commit adds a new hook Pass `bool canScheduleOn(RegisteredOperationName)` that indicates if the given pass can be scheduled on operations of the given type. This makes it easier to define constraints on generic passes without a) adding conditional checks to the beginning of the `runOnOperation`, or b) defining a new pass type that forwards from `runOnOperation` (after checking the invariants) to a new hook. This new hook is used to implement an `InterfacePass` pass class, that represents a generic pass that runs on operations of the given interface type. The PassManager will also verify that passes added to a pass manager can actually be scheduled on that pass manager, meaning that we will properly error when an Interface is scheduled on an operation that doesn't actually implement that interface. Differential Revision: https://reviews.llvm.org/D120791
1 parent 449b649 commit 9c9a431

File tree

8 files changed

+143
-44
lines changed

8 files changed

+143
-44
lines changed

mlir/include/mlir/Pass/Pass.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
namespace mlir {
2121
namespace detail {
2222
class OpToOpPassAdaptor;
23+
struct OpPassManagerImpl;
2324

2425
/// The state for a single execution of a pass. This provides a unified
2526
/// interface for accessing and initializing necessary state for pass execution.
@@ -184,6 +185,11 @@ class Pass {
184185
/// pipeline won't execute.
185186
virtual LogicalResult initialize(MLIRContext *context) { return success(); }
186187

188+
/// Indicate if the current pass can be scheduled on the given operation type.
189+
/// This is useful for generic operation passes to add restrictions on the
190+
/// operations they operate on.
191+
virtual bool canScheduleOn(RegisteredOperationName opName) const = 0;
192+
187193
/// Schedule an arbitrary pass pipeline on the provided operation.
188194
/// This can be invoke any time in a pass to dynamic schedule more passes.
189195
/// The provided operation must be the current one or one nested below.
@@ -313,6 +319,9 @@ class Pass {
313319
/// Allow access to 'clone'.
314320
friend class OpPassManager;
315321

322+
/// Allow access to 'canScheduleOn'.
323+
friend detail::OpPassManagerImpl;
324+
316325
/// Allow access to 'passState'.
317326
friend detail::OpToOpPassAdaptor;
318327

@@ -346,6 +355,11 @@ template <typename OpT = void> class OperationPass : public Pass {
346355
return pass->getOpName() == OpT::getOperationName();
347356
}
348357

358+
/// Indicate if the current pass can be scheduled on the given operation type.
359+
bool canScheduleOn(RegisteredOperationName opName) const final {
360+
return opName.getStringRef() == getOpName();
361+
}
362+
349363
/// Return the current operation being transformed.
350364
OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
351365

@@ -373,6 +387,46 @@ template <> class OperationPass<void> : public Pass {
373387
protected:
374388
OperationPass(TypeID passID) : Pass(passID) {}
375389
OperationPass(const OperationPass &) = default;
390+
391+
/// Indicate if the current pass can be scheduled on the given operation type.
392+
/// By default, generic operation passes can be scheduled on any operation.
393+
bool canScheduleOn(RegisteredOperationName opName) const override {
394+
return true;
395+
}
396+
};
397+
398+
/// Pass to transform an operation that implements the given interface.
399+
///
400+
/// Interface passes must not:
401+
/// - modify any other operations within the parent region, as other threads
402+
/// may be manipulating them concurrently.
403+
/// - modify any state within the parent operation, this includes adding
404+
/// additional operations.
405+
///
406+
/// Derived interface passes are expected to provide the following:
407+
/// - A 'void runOnOperation()' method.
408+
/// - A 'StringRef getName() const' method.
409+
/// - A 'std::unique_ptr<Pass> clonePass() const' method.
410+
template <typename InterfaceT>
411+
class InterfacePass : public OperationPass<> {
412+
protected:
413+
using OperationPass::OperationPass;
414+
415+
/// Indicate if the current pass can be scheduled on the given operation type.
416+
/// For an InterfacePass, this checks if the operation implements the given
417+
/// interface.
418+
bool canScheduleOn(RegisteredOperationName opName) const final {
419+
return opName.hasInterface<InterfaceT>();
420+
}
421+
422+
/// Return the current operation being transformed.
423+
InterfaceT getOperation() { return cast<InterfaceT>(Pass::getOperation()); }
424+
425+
/// Query an analysis for the current operation.
426+
template <typename AnalysisT>
427+
AnalysisT &getAnalysis() {
428+
return Pass::getAnalysis<AnalysisT, InterfaceT>();
429+
}
376430
};
377431

378432
/// This class provides a CRTP wrapper around a base pass class to define

mlir/include/mlir/Pass/PassBase.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,8 @@ class PassBase<string passArg, string base> {
9292
class Pass<string passArg, string operation = "">
9393
: PassBase<passArg, "::mlir::OperationPass<" # operation # ">">;
9494

95+
// This class represents an mlir::InterfacePass.
96+
class InterfacePass<string passArg, string interface>
97+
: PassBase<passArg, "::mlir::InterfacePass<" # interface # ">">;
98+
9599
#endif // MLIR_PASS_PASSBASE

mlir/include/mlir/Pass/PassManager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class OpPassManager {
9898
size_t size() const;
9999

100100
/// Return the operation name that this pass manager operates on.
101-
StringAttr getOpName(MLIRContext &context) const;
101+
OperationName getOpName(MLIRContext &context) const;
102102

103103
/// Return the operation name that this pass manager operates on.
104104
StringRef getOpName() const;

mlir/lib/Pass/Pass.cpp

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ void Pass::printAsTextualPipeline(raw_ostream &os) {
8080
namespace mlir {
8181
namespace detail {
8282
struct OpPassManagerImpl {
83-
OpPassManagerImpl(StringAttr identifier, OpPassManager::Nesting nesting)
84-
: name(identifier.str()), identifier(identifier),
83+
OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting)
84+
: name(opName.getStringRef()), opName(opName),
8585
initializationGeneration(0), nesting(nesting) {}
8686
OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
8787
: name(name), initializationGeneration(0), nesting(nesting) {}
@@ -102,23 +102,24 @@ struct OpPassManagerImpl {
102102
/// preserved.
103103
void clear();
104104

105-
/// Coalesce adjacent AdaptorPasses into one large adaptor. This runs
106-
/// recursively through the pipeline graph.
107-
void coalesceAdjacentAdaptorPasses();
105+
/// Finalize the pass list in preparation for execution. This includes
106+
/// coalescing adjacent pass managers when possible, verifying scheduled
107+
/// passes, etc.
108+
LogicalResult finalizePassList(MLIRContext *ctx);
108109

109-
/// Return the operation name of this pass manager as an identifier.
110-
StringAttr getOpName(MLIRContext &context) {
111-
if (!identifier)
112-
identifier = StringAttr::get(&context, name);
113-
return *identifier;
110+
/// Return the operation name of this pass manager.
111+
OperationName getOpName(MLIRContext &context) {
112+
if (!opName)
113+
opName = OperationName(name, &context);
114+
return *opName;
114115
}
115116

116117
/// The name of the operation that passes of this pass manager operate on.
117118
std::string name;
118119

119-
/// The cached identifier (internalized in the context) for the name of the
120+
/// The cached OperationName (internalized in the context) for the name of the
120121
/// operation that passes of this pass manager operate on.
121-
Optional<StringAttr> identifier;
122+
Optional<OperationName> opName;
122123

123124
/// The set of passes to run as part of this pass manager.
124125
std::vector<std::unique_ptr<Pass>> passes;
@@ -173,18 +174,12 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
173174

174175
void OpPassManagerImpl::clear() { passes.clear(); }
175176

176-
void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
177-
// Bail out early if there are no adaptor passes.
178-
if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
179-
return isa<OpToOpPassAdaptor>(pass.get());
180-
}))
181-
return;
182-
177+
LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) {
183178
// Walk the pass list and merge adjacent adaptors.
184179
OpToOpPassAdaptor *lastAdaptor = nullptr;
185-
for (auto &passe : passes) {
180+
for (auto &pass : passes) {
186181
// Check to see if this pass is an adaptor.
187-
if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(passe.get())) {
182+
if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get())) {
188183
// If it is the first adaptor in a possible chain, remember it and
189184
// continue.
190185
if (!lastAdaptor) {
@@ -194,25 +189,39 @@ void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
194189

195190
// Otherwise, merge into the existing adaptor and delete the current one.
196191
currentAdaptor->mergeInto(*lastAdaptor);
197-
passe.reset();
192+
pass.reset();
198193
} else if (lastAdaptor) {
199-
// If this pass is not an adaptor, then coalesce and forget any existing
194+
// If this pass is not an adaptor, then finalize and forget any existing
200195
// adaptor.
201196
for (auto &pm : lastAdaptor->getPassManagers())
202-
pm.getImpl().coalesceAdjacentAdaptorPasses();
197+
if (failed(pm.getImpl().finalizePassList(ctx)))
198+
return failure();
203199
lastAdaptor = nullptr;
204200
}
205201
}
206202

207-
// If there was an adaptor at the end of the manager, coalesce it as well.
203+
// If there was an adaptor at the end of the manager, finalize it as well.
208204
if (lastAdaptor) {
209205
for (auto &pm : lastAdaptor->getPassManagers())
210-
pm.getImpl().coalesceAdjacentAdaptorPasses();
206+
if (failed(pm.getImpl().finalizePassList(ctx)))
207+
return failure();
211208
}
212209

213-
// Now that the adaptors have been merged, erase the empty slot corresponding
210+
// Now that the adaptors have been merged, erase any empty slots corresponding
214211
// to the merged adaptors that were nulled-out in the loop above.
212+
Optional<RegisteredOperationName> opName =
213+
getOpName(*ctx).getRegisteredInfo();
215214
llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
215+
216+
// Verify that all of the passes are valid for the operation.
217+
for (std::unique_ptr<Pass> &pass : passes) {
218+
if (opName && !pass->canScheduleOn(*opName)) {
219+
return emitError(UnknownLoc::get(ctx))
220+
<< "unable to schedule pass '" << pass->getName()
221+
<< "' on a PassManager intended to run on '" << name << "'!";
222+
}
223+
}
224+
return success();
216225
}
217226

218227
//===----------------------------------------------------------------------===//
@@ -279,7 +288,7 @@ OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
279288
StringRef OpPassManager::getOpName() const { return impl->name; }
280289

281290
/// Return the operation name that this pass manager operates on.
282-
StringAttr OpPassManager::getOpName(MLIRContext &context) const {
291+
OperationName OpPassManager::getOpName(MLIRContext &context) const {
283292
return impl->getOpName(context);
284293
}
285294

@@ -367,9 +376,9 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
367376
"nested under the current operation the pass is processing";
368377
assert(pipeline.getOpName() == root->getName().getStringRef());
369378

370-
// Before running, make sure to coalesce any adjacent pass adaptors in the
371-
// pipeline.
372-
pipeline.getImpl().coalesceAdjacentAdaptorPasses();
379+
// Before running, finalize the passes held by the pipeline.
380+
if (failed(pipeline.getImpl().finalizePassList(root->getContext())))
381+
return failure();
373382

374383
// Initialize the user provided pipeline and execute the pipeline.
375384
if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
@@ -468,7 +477,7 @@ static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
468477
/// Find an operation pass manager that can operate on an operation of the given
469478
/// type, or nullptr if one does not exist.
470479
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
471-
StringAttr name,
480+
OperationName name,
472481
MLIRContext &context) {
473482
auto *it = llvm::find_if(
474483
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
@@ -538,8 +547,7 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
538547
for (auto &region : getOperation()->getRegions()) {
539548
for (auto &block : region) {
540549
for (auto &op : block) {
541-
auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
542-
*op.getContext());
550+
auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext());
543551
if (!mgr)
544552
continue;
545553

@@ -581,7 +589,7 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
581589
for (auto &block : region) {
582590
for (auto &op : block) {
583591
// Add this operation iff the name matches any of the pass managers.
584-
if (findPassManagerFor(mgrs, op.getName().getIdentifier(), *context))
592+
if (findPassManagerFor(mgrs, op.getName(), *context))
585593
opAMPairs.emplace_back(&op, am.nest(&op));
586594
}
587595
}
@@ -604,9 +612,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
604612
unsigned pmIndex = it - activePMs.begin();
605613

606614
// Get the pass manager for this operation and execute it.
607-
auto *pm =
608-
findPassManagerFor(asyncExecutors[pmIndex],
609-
opPMPair.first->getName().getIdentifier(), *context);
615+
auto *pm = findPassManagerFor(asyncExecutors[pmIndex],
616+
opPMPair.first->getName(), *context);
610617
assert(pm && "expected valid pass manager for operation");
611618

612619
unsigned initGeneration = pm->impl->initializationGeneration;
@@ -641,21 +648,21 @@ void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
641648
/// Run the passes within this manager on the provided operation.
642649
LogicalResult PassManager::run(Operation *op) {
643650
MLIRContext *context = getContext();
644-
assert(op->getName().getIdentifier() == getOpName(*context) &&
651+
assert(op->getName() == getOpName(*context) &&
645652
"operation has a different name than the PassManager or is from a "
646653
"different context");
647654

648-
// Before running, make sure to coalesce any adjacent pass adaptors in the
649-
// pipeline.
650-
getImpl().coalesceAdjacentAdaptorPasses();
651-
652655
// Register all dialects for the current pipeline.
653656
DialectRegistry dependentDialects;
654657
getDependentDialects(dependentDialects);
655658
context->appendDialectRegistry(dependentDialects);
656659
for (StringRef name : dependentDialects.getDialectNames())
657660
context->getOrLoadDialect(name);
658661

662+
// Before running, make sure to finalize the pipeline pass list.
663+
if (failed(getImpl().finalizePassList(context)))
664+
return failure();
665+
659666
// Initialize all of the passes within the pass manager with a new generation.
660667
llvm::hash_code newInitKey = context->getRegistryHash();
661668
if (newInitKey != initializationKey) {

mlir/test/Pass/interface-pass.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: mlir-opt %s -verify-diagnostics -pass-pipeline='builtin.func(test-interface-pass)' -o /dev/null
2+
3+
// Test that we run the interface pass on the function.
4+
5+
// expected-remark@below {{Executing interface pass on operation}}
6+
func @main() {
7+
return
8+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: not mlir-opt %s -pass-pipeline='test-interface-pass' 2>&1 | FileCheck %s
2+
3+
// Test that we emit an error when an interface pass is added to a pass manager it can't be scheduled on.
4+
5+
// CHECK: unable to schedule pass '{{.*}}' on a PassManager intended to run on 'builtin.module'!
6+
7+
func @main() {
8+
return
9+
}

mlir/test/lib/Pass/TestPassManager.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ struct TestFunctionPass
2929
return "Test a function pass in the pass manager";
3030
}
3131
};
32+
class TestInterfacePass
33+
: public PassWrapper<TestInterfacePass,
34+
InterfacePass<FunctionOpInterface>> {
35+
void runOnOperation() final {
36+
getOperation()->emitRemark() << "Executing interface pass on operation";
37+
}
38+
StringRef getArgument() const final { return "test-interface-pass"; }
39+
StringRef getDescription() const final {
40+
return "Test an interface pass (running on FunctionOpInterface) in the "
41+
"pass manager";
42+
}
43+
};
3244
class TestOptionsPass
3345
: public PassWrapper<TestOptionsPass, OperationPass<FuncOp>> {
3446
public:
@@ -128,6 +140,8 @@ void registerPassManagerTestPass() {
128140

129141
PassRegistration<TestFunctionPass>();
130142

143+
PassRegistration<TestInterfacePass>();
144+
131145
PassRegistration<TestCrashRecoveryPass>();
132146
PassRegistration<TestFailurePass>();
133147

mlir/unittests/Pass/PassManagerTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ struct InvalidPass : Pass {
8181
InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
8282
StringRef getName() const override { return "Invalid Pass"; }
8383
void runOnOperation() override {}
84+
bool canScheduleOn(RegisteredOperationName opName) const override {
85+
return true;
86+
}
8487

8588
/// A clone method to create a copy of this pass.
8689
std::unique_ptr<Pass> clonePass() const override {

0 commit comments

Comments
 (0)