Skip to content

[mlir][Pass] Move PassExecutionAction to Pass.h, NFC. #74850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mlir/include/mlir/Pass/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_PASS_PASS_H
#define MLIR_PASS_PASS_H

#include "mlir/IR/Action.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
Expand Down Expand Up @@ -457,6 +458,52 @@ class PassWrapper : public BaseT {
}
};

/// This class encapsulates the "action" of executing a single pass. This allows
/// a user of the Action infrastructure to query information about an action in
/// (for example) a breakpoint context. You could use it like this:
///
/// auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
/// if (auto passExec = dyn_cast<PassExecutionAction>(anAction))
/// record(passExec.getPass());
/// return ExecutionContext::Apply;
/// };
/// ExecutionContext exeCtx(onBreakpoint);
///
class PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
using Base = tracing::ActionImpl<PassExecutionAction>;

public:
/// Define a TypeID for this PassExecutionAction.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PassExecutionAction)
/// Construct a PassExecutionAction. This is called by the OpToOpPassAdaptor
/// when it calls `executeAction`.
PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass);

/// The tag required by ActionImpl to identify this action.
static constexpr StringLiteral tag = "pass-execution";

/// Print a textual version of this action to `os`.
void print(raw_ostream &os) const override;

/// Get the pass that will be executed by this action. This is not a class of
/// passes, or all instances of a pass kind, this is a single pass.
const Pass &getPass() const { return pass; }

/// Get the operation that is the base of this pass. For example, an
/// OperationPass<ModuleOp> would return a ModuleOp.
Operation *getOp() const;

public:
/// Reference to the pass being run. Notice that this will *not* extend the
/// lifetime of the pass, and so this class is therefore unsafe to keep past
/// the lifetime of the `executeAction` call.
const Pass &pass;

/// The base op for this pass. For an OperationPass<ModuleOp>, we would have a
/// ModuleOp here.
Operation *op;
};

} // namespace mlir

#endif // MLIR_PASS_PASS_H
10 changes: 10 additions & 0 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,21 @@ using namespace mlir::detail;
// PassExecutionAction
//===----------------------------------------------------------------------===//

PassExecutionAction::PassExecutionAction(ArrayRef<IRUnit> irUnits,
const Pass &pass)
: Base(irUnits), pass(pass) {}

void PassExecutionAction::print(raw_ostream &os) const {
os << llvm::formatv("`{0}` running `{1}` on Operation `{2}`", tag,
pass.getName(), getOp()->getName());
}

Operation *PassExecutionAction::getOp() const {
ArrayRef<IRUnit> irUnits = getContextIRUnits();
return irUnits.empty() ? nullptr
: llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
}

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 0 additions & 20 deletions mlir/lib/Pass/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,6 @@
#include "llvm/Support/FormatVariadic.h"

namespace mlir {
/// Encapsulate the "action" of executing a single pass, used for the MLIR
/// tracing infrastructure.
struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
using Base = tracing::ActionImpl<PassExecutionAction>;
PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass)
: Base(irUnits), pass(pass) {}
static constexpr StringLiteral tag = "pass-execution";
void print(raw_ostream &os) const override;
const Pass &getPass() const { return pass; }
Operation *getOp() const {
ArrayRef<IRUnit> irUnits = getContextIRUnits();
return irUnits.empty() ? nullptr
: llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
}

public:
const Pass &pass;
Operation *op;
};

namespace detail {

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/Pass/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ add_mlir_unittest(MLIRPassTests
)
target_link_libraries(MLIRPassTests
PRIVATE
MLIRDebug
MLIRFuncDialect
MLIRPass)
100 changes: 100 additions & 0 deletions mlir/unittests/Pass/PassManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Pass/PassManager.h"
#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
Expand Down Expand Up @@ -86,6 +88,104 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
}
}

/// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
struct AddAttrFunctionPass
: public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)

void runOnOperation() override {
func::FuncOp op = getOperation();
Builder builder(op->getParentOfType<ModuleOp>());
if (op->hasAttr("didProcess"))
op->setAttr("didProcessAgain", builder.getUnitAttr());

// We always want to set this one.
op->setAttr("didProcess", builder.getUnitAttr());
}
};

/// Simple pass to annotate a func::FuncOp with a single attribute
/// `didProcess2`.
struct AddSecondAttrFunctionPass
: public PassWrapper<AddSecondAttrFunctionPass,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)

void runOnOperation() override {
func::FuncOp op = getOperation();
Builder builder(op->getParentOfType<ModuleOp>());
op->setAttr("didProcess2", builder.getUnitAttr());
}
};

TEST(PassManagerTest, ExecutionAction) {
MLIRContext context;
context.loadDialect<func::FuncDialect>();
Builder builder(&context);

// Create a module with 2 functions.
OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
auto f =
func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
builder.getFunctionType(std::nullopt, std::nullopt));
f.setPrivate();
module->push_back(f);

// Instantiate our passes.
auto pm = PassManager::on<ModuleOp>(&context);
auto pass = std::make_unique<AddAttrFunctionPass>();
auto *passPtr = pass.get();
pm.addNestedPass<func::FuncOp>(std::move(pass));
pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
// Duplicate the first pass to ensure that we *only* run the *first* pass, not
// all instances of this pass kind. Notice that this pass (and the test as a
// whole) are built to ensure that we can run just a single pass out of a
// pipeline that may contain duplicates.
pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());

// Use the action manager to only hit the first pass, not the second one.
auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
-> tracing::ExecutionContext::Control {
// Not a PassExecutionAction, apply the action.
auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction());
if (!passExec)
return tracing::ExecutionContext::Next;

// If this isn't a function, apply the action.
if (!isa<func::FuncOp>(passExec->getOp()))
return tracing::ExecutionContext::Next;

// Only apply the first function pass. Not all instances of the first pass,
// only the first pass.
if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
return tracing::ExecutionContext::Next;

// Do not apply any other passes in the pass manager.
return tracing::ExecutionContext::Skip;
};

// Set up our breakpoint manager.
tracing::TagBreakpointManager simpleManager;
tracing::ExecutionContext executionCtx(onBreakpoint);
executionCtx.addBreakpointManager(&simpleManager);
simpleManager.addBreakpoint(PassExecutionAction::tag);

// Register the execution context in the MLIRContext.
context.registerActionHandler(executionCtx);

// Run the pass manager, expecting our handler to be called.
LogicalResult result = pm.run(module.get());
EXPECT_TRUE(succeeded(result));

// Verify that each function got annotated with `didProcess` and *not*
// `didProcess2`.
for (func::FuncOp func : module->getOps<func::FuncOp>()) {
ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
}
}

namespace {
struct InvalidPass : Pass {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
Expand Down