Skip to content

Commit 46708a5

Browse files
authored
[mlir][Pass] Move PassExecutionAction to Pass.h, NFC. (#74850)
This patch moves PassExecutionAction to Pass.h so that it can be used by the action framework to introspect and intercede in pass managers that might be set up opaquely. This provides for a very particular use case, which essentially involves being able to intercede in a PassManager and skip or apply individual passes. Because of this, this patch also adds a test for this use case to verify that it could in fact work.
1 parent 687e63a commit 46708a5

File tree

5 files changed

+158
-20
lines changed

5 files changed

+158
-20
lines changed

mlir/include/mlir/Pass/Pass.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_PASS_PASS_H
1010
#define MLIR_PASS_PASS_H
1111

12+
#include "mlir/IR/Action.h"
1213
#include "mlir/Pass/AnalysisManager.h"
1314
#include "mlir/Pass/PassRegistry.h"
1415
#include "mlir/Support/LogicalResult.h"
@@ -457,6 +458,52 @@ class PassWrapper : public BaseT {
457458
}
458459
};
459460

461+
/// This class encapsulates the "action" of executing a single pass. This allows
462+
/// a user of the Action infrastructure to query information about an action in
463+
/// (for example) a breakpoint context. You could use it like this:
464+
///
465+
/// auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
466+
/// if (auto passExec = dyn_cast<PassExecutionAction>(anAction))
467+
/// record(passExec.getPass());
468+
/// return ExecutionContext::Apply;
469+
/// };
470+
/// ExecutionContext exeCtx(onBreakpoint);
471+
///
472+
class PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
473+
using Base = tracing::ActionImpl<PassExecutionAction>;
474+
475+
public:
476+
/// Define a TypeID for this PassExecutionAction.
477+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PassExecutionAction)
478+
/// Construct a PassExecutionAction. This is called by the OpToOpPassAdaptor
479+
/// when it calls `executeAction`.
480+
PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass);
481+
482+
/// The tag required by ActionImpl to identify this action.
483+
static constexpr StringLiteral tag = "pass-execution";
484+
485+
/// Print a textual version of this action to `os`.
486+
void print(raw_ostream &os) const override;
487+
488+
/// Get the pass that will be executed by this action. This is not a class of
489+
/// passes, or all instances of a pass kind, this is a single pass.
490+
const Pass &getPass() const { return pass; }
491+
492+
/// Get the operation that is the base of this pass. For example, an
493+
/// OperationPass<ModuleOp> would return a ModuleOp.
494+
Operation *getOp() const;
495+
496+
public:
497+
/// Reference to the pass being run. Notice that this will *not* extend the
498+
/// lifetime of the pass, and so this class is therefore unsafe to keep past
499+
/// the lifetime of the `executeAction` call.
500+
const Pass &pass;
501+
502+
/// The base op for this pass. For an OperationPass<ModuleOp>, we would have a
503+
/// ModuleOp here.
504+
Operation *op;
505+
};
506+
460507
} // namespace mlir
461508

462509
#endif // MLIR_PASS_PASS_H

mlir/lib/Pass/Pass.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,21 @@ using namespace mlir::detail;
3636
// PassExecutionAction
3737
//===----------------------------------------------------------------------===//
3838

39+
PassExecutionAction::PassExecutionAction(ArrayRef<IRUnit> irUnits,
40+
const Pass &pass)
41+
: Base(irUnits), pass(pass) {}
42+
3943
void PassExecutionAction::print(raw_ostream &os) const {
4044
os << llvm::formatv("`{0}` running `{1}` on Operation `{2}`", tag,
4145
pass.getName(), getOp()->getName());
4246
}
4347

48+
Operation *PassExecutionAction::getOp() const {
49+
ArrayRef<IRUnit> irUnits = getContextIRUnits();
50+
return irUnits.empty() ? nullptr
51+
: llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
52+
}
53+
4454
//===----------------------------------------------------------------------===//
4555
// Pass
4656
//===----------------------------------------------------------------------===//

mlir/lib/Pass/PassDetail.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,6 @@
1515
#include "llvm/Support/FormatVariadic.h"
1616

1717
namespace mlir {
18-
/// Encapsulate the "action" of executing a single pass, used for the MLIR
19-
/// tracing infrastructure.
20-
struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
21-
using Base = tracing::ActionImpl<PassExecutionAction>;
22-
PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass)
23-
: Base(irUnits), pass(pass) {}
24-
static constexpr StringLiteral tag = "pass-execution";
25-
void print(raw_ostream &os) const override;
26-
const Pass &getPass() const { return pass; }
27-
Operation *getOp() const {
28-
ArrayRef<IRUnit> irUnits = getContextIRUnits();
29-
return irUnits.empty() ? nullptr
30-
: llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
31-
}
32-
33-
public:
34-
const Pass &pass;
35-
Operation *op;
36-
};
37-
3818
namespace detail {
3919

4020
//===----------------------------------------------------------------------===//

mlir/unittests/Pass/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ add_mlir_unittest(MLIRPassTests
55
)
66
target_link_libraries(MLIRPassTests
77
PRIVATE
8+
MLIRDebug
89
MLIRFuncDialect
910
MLIRPass)

mlir/unittests/Pass/PassManagerTest.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Pass/PassManager.h"
10+
#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
11+
#include "mlir/Debug/ExecutionContext.h"
1012
#include "mlir/Dialect/Func/IR/FuncOps.h"
1113
#include "mlir/IR/Builders.h"
1214
#include "mlir/IR/BuiltinOps.h"
@@ -86,6 +88,104 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
8688
}
8789
}
8890

91+
/// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
92+
struct AddAttrFunctionPass
93+
: public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
94+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
95+
96+
void runOnOperation() override {
97+
func::FuncOp op = getOperation();
98+
Builder builder(op->getParentOfType<ModuleOp>());
99+
if (op->hasAttr("didProcess"))
100+
op->setAttr("didProcessAgain", builder.getUnitAttr());
101+
102+
// We always want to set this one.
103+
op->setAttr("didProcess", builder.getUnitAttr());
104+
}
105+
};
106+
107+
/// Simple pass to annotate a func::FuncOp with a single attribute
108+
/// `didProcess2`.
109+
struct AddSecondAttrFunctionPass
110+
: public PassWrapper<AddSecondAttrFunctionPass,
111+
OperationPass<func::FuncOp>> {
112+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
113+
114+
void runOnOperation() override {
115+
func::FuncOp op = getOperation();
116+
Builder builder(op->getParentOfType<ModuleOp>());
117+
op->setAttr("didProcess2", builder.getUnitAttr());
118+
}
119+
};
120+
121+
TEST(PassManagerTest, ExecutionAction) {
122+
MLIRContext context;
123+
context.loadDialect<func::FuncDialect>();
124+
Builder builder(&context);
125+
126+
// Create a module with 2 functions.
127+
OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
128+
auto f =
129+
func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
130+
builder.getFunctionType(std::nullopt, std::nullopt));
131+
f.setPrivate();
132+
module->push_back(f);
133+
134+
// Instantiate our passes.
135+
auto pm = PassManager::on<ModuleOp>(&context);
136+
auto pass = std::make_unique<AddAttrFunctionPass>();
137+
auto *passPtr = pass.get();
138+
pm.addNestedPass<func::FuncOp>(std::move(pass));
139+
pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
140+
// Duplicate the first pass to ensure that we *only* run the *first* pass, not
141+
// all instances of this pass kind. Notice that this pass (and the test as a
142+
// whole) are built to ensure that we can run just a single pass out of a
143+
// pipeline that may contain duplicates.
144+
pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
145+
146+
// Use the action manager to only hit the first pass, not the second one.
147+
auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
148+
-> tracing::ExecutionContext::Control {
149+
// Not a PassExecutionAction, apply the action.
150+
auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction());
151+
if (!passExec)
152+
return tracing::ExecutionContext::Next;
153+
154+
// If this isn't a function, apply the action.
155+
if (!isa<func::FuncOp>(passExec->getOp()))
156+
return tracing::ExecutionContext::Next;
157+
158+
// Only apply the first function pass. Not all instances of the first pass,
159+
// only the first pass.
160+
if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
161+
return tracing::ExecutionContext::Next;
162+
163+
// Do not apply any other passes in the pass manager.
164+
return tracing::ExecutionContext::Skip;
165+
};
166+
167+
// Set up our breakpoint manager.
168+
tracing::TagBreakpointManager simpleManager;
169+
tracing::ExecutionContext executionCtx(onBreakpoint);
170+
executionCtx.addBreakpointManager(&simpleManager);
171+
simpleManager.addBreakpoint(PassExecutionAction::tag);
172+
173+
// Register the execution context in the MLIRContext.
174+
context.registerActionHandler(executionCtx);
175+
176+
// Run the pass manager, expecting our handler to be called.
177+
LogicalResult result = pm.run(module.get());
178+
EXPECT_TRUE(succeeded(result));
179+
180+
// Verify that each function got annotated with `didProcess` and *not*
181+
// `didProcess2`.
182+
for (func::FuncOp func : module->getOps<func::FuncOp>()) {
183+
ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
184+
ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
185+
ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
186+
}
187+
}
188+
89189
namespace {
90190
struct InvalidPass : Pass {
91191
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)

0 commit comments

Comments
 (0)