Skip to content

[MLIR][python bindings] invalidate ops after PassManager run #69746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ DEFINE_C_API_STRUCT(MlirValue, const void);
///
/// A named attribute is essentially a (name, attribute) pair where the name is
/// a string.

struct MlirNamedAttribute {
MlirIdentifier name;
MlirAttribute attribute;
Expand Down Expand Up @@ -698,6 +697,24 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
/// ownership is transferred to the block of the other operation.
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
MlirOperation other);

/// Traversal order for operation walk.
typedef enum MlirWalkOrder {
MlirWalkPreOrder,
MlirWalkPostOrder
} MlirWalkOrder;

/// Operation walker type. The handler is passed an (opaque) reference to an
/// operation a pointer to a `userData`.
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);

/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
/// `*userData` is passed to the callback as well and can be used to tunnel some
/// some context or other data into the callback.
MLIR_CAPI_EXPORTED
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder);

//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,11 @@ size_t PyMlirContext::clearLiveOperations() {
return numInvalidated;
}

void PyMlirContext::setOperationInvalid(MlirOperation op) {
if (liveOperations.contains(op.ptr))
liveOperations[op.ptr].second->setInvalid();
}

size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }

pybind11::object PyMlirContext::contextEnter() {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ class PyMlirContext {
/// place.
size_t clearLiveOperations();

/// Sets an operation invalid. This is useful for when some non-bindings
/// code destroys the operation and the bindings need to made aware. For
/// example, in the case when pass manager is run.
void setOperationInvalid(MlirOperation op);

/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();
Expand Down
37 changes: 30 additions & 7 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir-c/Pass.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlir;
using namespace mlir::python;

Expand Down Expand Up @@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
return new PyPassManager(passManager);
}),
py::arg("anchor_op") = py::str("any"),
py::arg("context") = py::none(),
"anchor_op"_a = py::str("any"), "context"_a = py::none(),
"Create a new PassManager for the current (or provided) Context.")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyPassManager::getCapsule)
Expand All @@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
[](PyPassManager &passManager, bool enable) {
mlirPassManagerEnableVerifier(passManager.get(), enable);
},
py::arg("enable"), "Enable / disable verify-each.")
"enable"_a, "Enable / disable verify-each.")
.def_static(
"parse",
[](const std::string &pipeline, DefaultingPyMlirContext context) {
Expand All @@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
throw py::value_error(std::string(errorMsg.join()));
return new PyPassManager(passManager);
},
py::arg("pipeline"), py::arg("context") = py::none(),
"pipeline"_a, "context"_a = py::none(),
"Parse a textual pass-pipeline and return a top-level PassManager "
"that can be applied on a Module. Throw a ValueError if the pipeline "
"can't be parsed")
Expand All @@ -111,20 +111,43 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
if (mlirLogicalResultIsFailure(status))
throw py::value_error(std::string(errorMsg.join()));
},
py::arg("pipeline"),
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op) {
[](PyPassManager &passManager, PyOperationBase &op,
bool invalidateOps) {
if (invalidateOps) {
typedef struct {
PyOperation &rootOp;
bool rootSeen;
} callBackData;
callBackData data{op.getOperation(), false};
// Mark all ops below the op that the passmanager will be rooted
// at (but not op itself - note the preorder) as invalid.
MlirOperationWalkCallback invalidatingCallback =
[](MlirOperation op, void *userData) {
callBackData *data = static_cast<callBackData *>(userData);
if (LLVM_LIKELY(data->rootSeen))
data->rootOp.getOperation()
.getContext()
->setOperationInvalid(op);
else
data->rootSeen = true;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
}
// Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
passManager.get(), op.getOperation().get());
if (mlirLogicalResultIsFailure(status))
throw MLIRError("Failure while executing pass pipeline",
errors.take());
},
py::arg("operation"),
"operation"_a, "invalidate_ops"_a = true,
"Run the pass manager on the provided operation, raising an "
"MLIRError on failure.")
.def(
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"

Expand Down Expand Up @@ -705,6 +706,20 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
return unwrap(op)->moveBefore(unwrap(other));
}

void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder) {
switch (walkOrder) {

case MlirWalkPreOrder:
unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
[callback, userData](Operation *op) { callback(wrap(op), userData); });
break;
case MlirWalkPostOrder:
unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
[callback, userData](Operation *op) { callback(wrap(op), userData); });
}
}

//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
Expand Down
47 changes: 47 additions & 0 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,51 @@ int testSymbolTable(MlirContext ctx) {
return 0;
}

typedef struct {
const char *x;
} callBackData;

void walkCallBack(MlirOperation op, void *rootOpVoid) {
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
mlirIdentifierStr(mlirOperationGetName(op)).data);
}

int testOperationWalk(MlirContext ctx) {
// CHECK-LABEL: @testOperationWalk
fprintf(stderr, "@testOperationWalk\n");

const char *moduleString = "module {\n"
" func.func @foo() {\n"
" %1 = arith.constant 10: i32\n"
" arith.addi %1, %1: i32\n"
" return\n"
" }\n"
"}";
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));

callBackData data;
data.x = "i love you";

// CHECK: i love you: arith.constant
// CHECK: i love you: arith.addi
// CHECK: i love you: func.return
// CHECK: i love you: func.func
// CHECK: i love you: builtin.module
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPostOrder);

data.x = "i don't love you";
// CHECK: i don't love you: builtin.module
// CHECK: i don't love you: func.func
// CHECK: i don't love you: arith.constant
// CHECK: i don't love you: arith.addi
// CHECK: i don't love you: func.return
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPreOrder);
return 0;
}

int testDialectRegistry(void) {
fprintf(stderr, "@testDialectRegistry\n");

Expand Down Expand Up @@ -2349,6 +2394,8 @@ int main(void) {
return 14;
if (testDialectRegistry())
return 15;
if (testOperationWalk(ctx))
return 16;

testExplicitThreadPools();
testDiagnostics();
Expand Down
99 changes: 98 additions & 1 deletion mlir/test/python/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.func import FuncOp
from mlir.dialects.builtin import ModuleOp


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
Expand Down Expand Up @@ -33,6 +35,7 @@ def testCapsule():

run(testCapsule)


# CHECK-LABEL: TEST: testConstruct
@run
def testConstruct():
Expand Down Expand Up @@ -68,6 +71,7 @@ def testParseSuccess():

run(testParseSuccess)


# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSpacedPipeline
def testParseSpacedPipeline():
Expand All @@ -84,6 +88,7 @@ def testParseSpacedPipeline():

run(testParseSpacedPipeline)


# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
Expand All @@ -102,6 +107,7 @@ def testParseFail():

run(testParseFail)


# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
Expand Down Expand Up @@ -147,6 +153,7 @@ def testRunPipeline():
# CHECK: func.return , 1
run(testRunPipeline)


# CHECK-LABEL: TEST: testRunPipelineError
@run
def testRunPipelineError():
Expand All @@ -162,4 +169,94 @@ def testRunPipelineError():
# CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
# CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
# CHECK: >
print(f"Exception: <{e}>")
log(f"Exception: <{e}>")


# CHECK-LABEL: TEST: testPostPassOpInvalidation
@run
def testPostPassOpInvalidation():
with Context() as ctx:
module = ModuleOp.parse(
"""
module {
arith.constant 10
func.func @foo() {
arith.constant 10
return
}
}
"""
)

# CHECK: invalidate_ops=False
log("invalidate_ops=False")

outer_const_op = module.body.operations[0]
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
log(outer_const_op)

func_op = module.body.operations[1]
# CHECK: func.func @[[FOO:.*]]() {
# CHECK: %[[VAL1:.*]] = arith.constant 10 : i64
# CHECK: return
# CHECK: }
log(func_op)

inner_const_op = func_op.body.blocks[0].operations[0]
# CHECK: %[[VAL1]] = arith.constant 10 : i64
log(inner_const_op)

PassManager.parse("builtin.module(canonicalize)").run(
module, invalidate_ops=False
)
# CHECK: func.func @foo() {
# CHECK: return
# CHECK: }
log(func_op)

# CHECK: func.func @foo() {
# CHECK: return
# CHECK: }
log(module)

# CHECK: invalidate_ops=True
log("invalidate_ops=True")

module = ModuleOp.parse(
"""
module {
arith.constant 10
func.func @foo() {
arith.constant 10
return
}
}
"""
)
outer_const_op = module.body.operations[0]
func_op = module.body.operations[1]
inner_const_op = func_op.body.blocks[0].operations[0]

PassManager.parse("builtin.module(canonicalize)").run(module)
try:
log(func_op)
except RuntimeError as e:
# CHECK: the operation has been invalidated
log(e)

try:
log(outer_const_op)
except RuntimeError as e:
# CHECK: the operation has been invalidated
log(e)

try:
log(inner_const_op)
except RuntimeError as e:
# CHECK: the operation has been invalidated
log(e)

# CHECK: func.func @foo() {
# CHECK: return
# CHECK: }
log(module)