|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
9 | 9 | #include "mlir/Pass/PassManager.h"
|
| 10 | +#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h" |
| 11 | +#include "mlir/Debug/ExecutionContext.h" |
10 | 12 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
11 | 13 | #include "mlir/IR/Builders.h"
|
12 | 14 | #include "mlir/IR/BuiltinOps.h"
|
@@ -86,6 +88,104 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
|
86 | 88 | }
|
87 | 89 | }
|
88 | 90 |
|
| 91 | +/// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`. |
| 92 | +struct AddAttrFunctionPass |
| 93 | + : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> { |
| 94 | + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass) |
| 95 | + |
| 96 | + void runOnOperation() override { |
| 97 | + func::FuncOp op = getOperation(); |
| 98 | + Builder builder(op->getParentOfType<ModuleOp>()); |
| 99 | + if (op->hasAttr("didProcess")) |
| 100 | + op->setAttr("didProcessAgain", builder.getUnitAttr()); |
| 101 | + |
| 102 | + // We always want to set this one. |
| 103 | + op->setAttr("didProcess", builder.getUnitAttr()); |
| 104 | + } |
| 105 | +}; |
| 106 | + |
| 107 | +/// Simple pass to annotate a func::FuncOp with a single attribute |
| 108 | +/// `didProcess2`. |
| 109 | +struct AddSecondAttrFunctionPass |
| 110 | + : public PassWrapper<AddSecondAttrFunctionPass, |
| 111 | + OperationPass<func::FuncOp>> { |
| 112 | + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass) |
| 113 | + |
| 114 | + void runOnOperation() override { |
| 115 | + func::FuncOp op = getOperation(); |
| 116 | + Builder builder(op->getParentOfType<ModuleOp>()); |
| 117 | + op->setAttr("didProcess2", builder.getUnitAttr()); |
| 118 | + } |
| 119 | +}; |
| 120 | + |
| 121 | +TEST(PassManagerTest, ExecutionAction) { |
| 122 | + MLIRContext context; |
| 123 | + context.loadDialect<func::FuncDialect>(); |
| 124 | + Builder builder(&context); |
| 125 | + |
| 126 | + // Create a module with 2 functions. |
| 127 | + OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context))); |
| 128 | + auto f = |
| 129 | + func::FuncOp::create(builder.getUnknownLoc(), "process_me_once", |
| 130 | + builder.getFunctionType(std::nullopt, std::nullopt)); |
| 131 | + f.setPrivate(); |
| 132 | + module->push_back(f); |
| 133 | + |
| 134 | + // Instantiate our passes. |
| 135 | + auto pm = PassManager::on<ModuleOp>(&context); |
| 136 | + auto pass = std::make_unique<AddAttrFunctionPass>(); |
| 137 | + auto *passPtr = pass.get(); |
| 138 | + pm.addNestedPass<func::FuncOp>(std::move(pass)); |
| 139 | + pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>()); |
| 140 | + // Duplicate the first pass to ensure that we *only* run the *first* pass, not |
| 141 | + // all instances of this pass kind. Notice that this pass (and the test as a |
| 142 | + // whole) are built to ensure that we can run just a single pass out of a |
| 143 | + // pipeline that may contain duplicates. |
| 144 | + pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>()); |
| 145 | + |
| 146 | + // Use the action manager to only hit the first pass, not the second one. |
| 147 | + auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace) |
| 148 | + -> tracing::ExecutionContext::Control { |
| 149 | + // Not a PassExecutionAction, apply the action. |
| 150 | + auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction()); |
| 151 | + if (!passExec) |
| 152 | + return tracing::ExecutionContext::Next; |
| 153 | + |
| 154 | + // If this isn't a function, apply the action. |
| 155 | + if (!isa<func::FuncOp>(passExec->getOp())) |
| 156 | + return tracing::ExecutionContext::Next; |
| 157 | + |
| 158 | + // Only apply the first function pass. Not all instances of the first pass, |
| 159 | + // only the first pass. |
| 160 | + if (passExec->getPass().getThreadingSiblingOrThis() == passPtr) |
| 161 | + return tracing::ExecutionContext::Next; |
| 162 | + |
| 163 | + // Do not apply any other passes in the pass manager. |
| 164 | + return tracing::ExecutionContext::Skip; |
| 165 | + }; |
| 166 | + |
| 167 | + // Set up our breakpoint manager. |
| 168 | + tracing::TagBreakpointManager simpleManager; |
| 169 | + tracing::ExecutionContext executionCtx(onBreakpoint); |
| 170 | + executionCtx.addBreakpointManager(&simpleManager); |
| 171 | + simpleManager.addBreakpoint(PassExecutionAction::tag); |
| 172 | + |
| 173 | + // Register the execution context in the MLIRContext. |
| 174 | + context.registerActionHandler(executionCtx); |
| 175 | + |
| 176 | + // Run the pass manager, expecting our handler to be called. |
| 177 | + LogicalResult result = pm.run(module.get()); |
| 178 | + EXPECT_TRUE(succeeded(result)); |
| 179 | + |
| 180 | + // Verify that each function got annotated with `didProcess` and *not* |
| 181 | + // `didProcess2`. |
| 182 | + for (func::FuncOp func : module->getOps<func::FuncOp>()) { |
| 183 | + ASSERT_TRUE(func->getDiscardableAttr("didProcess")); |
| 184 | + ASSERT_FALSE(func->getDiscardableAttr("didProcess2")); |
| 185 | + ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain")); |
| 186 | + } |
| 187 | +} |
| 188 | + |
89 | 189 | namespace {
|
90 | 190 | struct InvalidPass : Pass {
|
91 | 191 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
|
|
0 commit comments