[MLIR][python bindings] invalidate ops after PassManager run #69746
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
Changes
Fixes #69730 (also see https://reviews.llvm.org/D155543).
There are two things outstanding (why I didn't land before):
Supposing there's no desire for the slightly more efficient but highly convoluted approach, I can land this "posthaste".
Full diff: https://github.com/llvm/llvm-project/pull/69746.diff
6 Files Affected:
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e361f33a0d83641..3163c3cc40c58b1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -698,6 +698,21 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
/// ownership is transferred to the block of the other operation.
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
MlirOperation other);
+
+typedef enum MlirWalkOrder {
+ MlirWalkPreOrder,
+ MlirWalkPostOrder
+} MlirWalkOrder;
+
+typedef void (*MlirOperationWalkCallback)(MlirOperation, void *);
+
+/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
+/// `*userData` is passed to the callback as well and can be used to tunnel
+/// some context or other data into the callback.
+MLIR_CAPI_EXPORTED
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+ void *userData, MlirWalkOrder walkOrder);
+
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 389a4621c14e594..a8ea1a381edb96e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,6 +635,11 @@ size_t PyMlirContext::clearLiveOperations() {
return numInvalidated;
}
+void PyMlirContext::setOperationInvalid(MlirOperation op) {
+ if (liveOperations.contains(op.ptr))
+ liveOperations[op.ptr].second->setInvalid();
+}
+
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index c5412e735dddcb5..26292885711a4e4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,6 +209,11 @@ class PyMlirContext {
/// place.
size_t clearLiveOperations();
+ /// Sets an operation invalid. This is useful for when some non-bindings
+ /// code destroys the operation and the bindings need to be made aware. For
+ /// example, in the case when pass manager is run.
+ void setOperationInvalid(MlirOperation op);
+
/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdbfcfbc22957a6..5e7a918d1bad55b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -13,6 +13,7 @@
#include "mlir-c/Pass.h"
namespace py = pybind11;
+using namespace py::literals;
using namespace mlir;
using namespace mlir::python;
@@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
return new PyPassManager(passManager);
}),
- py::arg("anchor_op") = py::str("any"),
- py::arg("context") = py::none(),
+ "anchor_op"_a = py::str("any"), "context"_a = py::none(),
"Create a new PassManager for the current (or provided) Context.")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyPassManager::getCapsule)
@@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
[](PyPassManager &passManager, bool enable) {
mlirPassManagerEnableVerifier(passManager.get(), enable);
},
- py::arg("enable"), "Enable / disable verify-each.")
+ "enable"_a, "Enable / disable verify-each.")
.def_static(
"parse",
[](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
throw py::value_error(std::string(errorMsg.join()));
return new PyPassManager(passManager);
},
- py::arg("pipeline"), py::arg("context") = py::none(),
+ "pipeline"_a, "context"_a = py::none(),
"Parse a textual pass-pipeline and return a top-level PassManager "
"that can be applied on a Module. Throw a ValueError if the pipeline "
"can't be parsed")
@@ -111,12 +111,30 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
if (mlirLogicalResultIsFailure(status))
throw py::value_error(std::string(errorMsg.join()));
},
- py::arg("pipeline"),
+ "pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
.def(
"run",
- [](PyPassManager &passManager, PyOperationBase &op) {
+ [](PyPassManager &passManager, PyOperationBase &op,
+ bool invalidateOps) {
+ if (invalidateOps) {
+ // Mark all ops below the op that the passmanager will be rooted
+ // at as invalid.
+ MlirOperationWalkCallback invalidatingCallback =
+ [](MlirOperation op, void *rootOpVoid) {
+ PyOperation *rootOp =
+ static_cast<PyOperation *>(rootOpVoid);
+ if (!mlirOperationEqual(rootOp->get(), op)) {
+ rootOp->getOperation().getContext()->setOperationInvalid(
+ op);
+ }
+ };
+ mlirOperationWalk(op.getOperation(), invalidatingCallback,
+ static_cast<void *>(&op.getOperation()),
+ MlirWalkPostOrder);
+ }
+ // Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
passManager.get(), op.getOperation().get());
@@ -124,7 +142,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
throw MLIRError("Failure while executing pass pipeline",
errors.take());
},
- py::arg("operation"),
+ "operation"_a, "invalidate_ops"_a = true,
"Run the pass manager on the provided operation, raising an "
"MLIRError on failure.")
.def(
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c1abbbe364611af..0a5151751873f2b 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
@@ -705,6 +706,20 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
return unwrap(op)->moveBefore(unwrap(other));
}
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+ void *userData, MlirWalkOrder walkOrder) {
+ switch (walkOrder) {
+
+ case MlirWalkPreOrder:
+ unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
+ [callback, userData](Operation *op) { callback(wrap(op), userData); });
+ break;
+ case MlirWalkPostOrder:
+ unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
+ [callback, userData](Operation *op) { callback(wrap(op), userData); });
+ }
+}
+
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 4b3a02ac42bd9b1..e7f79ddc75113e0 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -4,6 +4,8 @@
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.func import FuncOp
+from mlir.dialects.builtin import ModuleOp
+
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@ def testCapsule():
run(testCapsule)
+
# CHECK-LABEL: TEST: testConstruct
@run
def testConstruct():
@@ -68,6 +71,7 @@ def testParseSuccess():
run(testParseSuccess)
+
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSpacedPipeline
def testParseSpacedPipeline():
@@ -84,6 +88,7 @@ def testParseSpacedPipeline():
run(testParseSpacedPipeline)
+
# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
@@ -102,6 +107,7 @@ def testParseFail():
run(testParseFail)
+
# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
@@ -147,6 +153,7 @@ def testRunPipeline():
# CHECK: func.return , 1
run(testRunPipeline)
+
# CHECK-LABEL: TEST: testRunPipelineError
@run
def testRunPipelineError():
@@ -162,4 +169,94 @@ def testRunPipelineError():
# CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
# CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
# CHECK: >
- print(f"Exception: <{e}>")
+ log(f"Exception: <{e}>")
+
+
+# CHECK-LABEL: TEST: testPostPassOpInvalidation
+@run
+def testPostPassOpInvalidation():
+ with Context() as ctx:
+ module = ModuleOp.parse(
+ """
+ module {
+ arith.constant 10
+ func.func @foo() {
+ arith.constant 10
+ return
+ }
+ }
+ """
+ )
+
+ # CHECK: invalidate_ops=False
+ log("invalidate_ops=False")
+
+ outer_const_op = module.body.operations[0]
+ # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
+ log(outer_const_op)
+
+ func_op = module.body.operations[1]
+ # CHECK: func.func @[[FOO:.*]]() {
+ # CHECK: %[[VAL1:.*]] = arith.constant 10 : i64
+ # CHECK: return
+ # CHECK: }
+ log(func_op)
+
+ inner_const_op = func_op.body.blocks[0].operations[0]
+ # CHECK: %[[VAL1]] = arith.constant 10 : i64
+ log(inner_const_op)
+
+ PassManager.parse("builtin.module(canonicalize)").run(
+ module, invalidate_ops=False
+ )
+ # CHECK: func.func @foo() {
+ # CHECK: return
+ # CHECK: }
+ log(func_op)
+
+ # CHECK: func.func @foo() {
+ # CHECK: return
+ # CHECK: }
+ log(module)
+
+ # CHECK: invalidate_ops=True
+ log("invalidate_ops=True")
+
+ module = ModuleOp.parse(
+ """
+ module {
+ arith.constant 10
+ func.func @foo() {
+ arith.constant 10
+ return
+ }
+ }
+ """
+ )
+ outer_const_op = module.body.operations[0]
+ func_op = module.body.operations[1]
+ inner_const_op = func_op.body.blocks[0].operations[0]
+
+ PassManager.parse("builtin.module(canonicalize)").run(module)
+ try:
+ log(func_op)
+ except RuntimeError as e:
+ # CHECK: the operation has been invalidated
+ log(e)
+
+ try:
+ log(outer_const_op)
+ except RuntimeError as e:
+ # CHECK: the operation has been invalidated
+ log(e)
+
+ try:
+ log(inner_const_op)
+ except RuntimeError as e:
+ # CHECK: the operation has been invalidated
+ log(e)
+
+ # CHECK: func.func @foo() {
+ # CHECK: return
+ # CHECK: }
+ log(module)
LGTM.
Stupid suggestion: if we could "revalidate" ops, we could invalidate everything including the root and then revalidate the root.
but if "validation" becomes reversible then it might be wise to put One alternative is to make the walk
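The trade-off raised here can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions: `ToyHandle`, `invalidate_skipping_root`, and `invalidate_then_revalidate_root` are hypothetical names that do not exist in the MLIR bindings, and the `checks` counter merely stands in for the per-op `mlirOperationEqual(rootOp->get(), op)` comparison in the patch.

```python
class ToyHandle:
    """Hypothetical stand-in for a Python-level op handle (not an MLIR class)."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.valid = True


def post_order(op):
    """Yield every nested op before the op that contains it."""
    for child in op.children:
        yield from post_order(child)
    yield op


def invalidate_skipping_root(root):
    """Approach in the patch: compare each visited op against the root."""
    checks = 0
    for op in post_order(root):
        checks += 1  # one equality check per visited op
        if op is not root:
            op.valid = False
    return checks


def invalidate_then_revalidate_root(root):
    """Suggested alternative: invalidate everything, then restore the root."""
    for op in post_order(root):
        op.valid = False
    root.valid = True


def build_module():
    """Build a small op tree shaped like the module used in the test."""
    inner = ToyHandle("inner_const")
    ret = ToyHandle("return")
    func = ToyHandle("func", [inner, ret])
    outer = ToyHandle("outer_const")
    return ToyHandle("module", [outer, func])


def validity(root):
    """Map each op name to its validity flag."""
    return {op.name: op.valid for op in post_order(root)}


a = build_module()
checks = invalidate_skipping_root(a)

b = build_module()
invalidate_then_revalidate_root(b)
```

In this toy model both functions leave the tree in the same state; the second avoids the per-op comparison at the cost of making invalidation reversible, which is the concern voiced in the comment.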
So invalidation is on by default, and only in cases where an (expert) user knows it's safe can they opt out? (Reading on mobile so high chance of misreading :-))
Yup, "fail-dangerous" only if you're a professional 🙂.
`PyOperations` are Python-level handles to `Operation *` instances. When the latter are modified by C++, the former need to be invalidated. llvm#69746 implements such an invalidation mechanism by setting all `PyReferences` to `invalid`. However, that is not enough: they also need to be removed from the `liveOperations` map, since other parts of the code (such as `PyOperation::createDetached`) assume that the map only contains valid refs. This is required to actually solve the issue in llvm#69730.
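Why marking a handle invalid is not sufficient on its own can be sketched with a toy registry in plain Python. Everything below is hypothetical: `ToyContext`, `for_operation`, `set_invalid_only`, and `set_invalid_and_forget` are illustrative names rather than the bindings' API, and the integer key only stands in for an operation pointer. The scenario of a later lookup hitting the same key is an assumption made for illustration.

```python
class ToyHandle:
    """Hypothetical stand-in for a Python-level op handle."""

    def __init__(self, key):
        self.key = key
        self.valid = True


class ToyContext:
    """Toy registry keyed by an integer that stands in for an op pointer."""

    def __init__(self):
        self.live_operations = {}

    def for_operation(self, key):
        """Return the registered handle for key, creating one if missing."""
        handle = self.live_operations.get(key)
        if handle is None:
            handle = ToyHandle(key)
            self.live_operations[key] = handle
        return handle

    def set_invalid_only(self, key):
        """Mark the handle invalid but leave it in the registry."""
        handle = self.live_operations.get(key)
        if handle is not None:
            handle.valid = False

    def set_invalid_and_forget(self, key):
        """Mark the handle invalid and drop it from the registry."""
        handle = self.live_operations.pop(key, None)
        if handle is not None:
            handle.valid = False


KEY = 0x1000  # arbitrary illustrative key

# Invalidate without removing: a later lookup returns the stale handle.
ctx_a = ToyContext()
stale_a = ctx_a.for_operation(KEY)
ctx_a.set_invalid_only(KEY)
again_a = ctx_a.for_operation(KEY)

# Invalidate and remove: a later lookup creates a fresh, valid handle.
ctx_b = ToyContext()
stale_b = ctx_b.for_operation(KEY)
ctx_b.set_invalid_and_forget(KEY)
fresh_b = ctx_b.for_operation(KEY)
```

In the first context the registry hands back an invalid handle, which violates the "map only contains valid refs" assumption described above; in the second, the old handle stays invalid while the registry stays clean.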
Fixes #69730 (also see https://reviews.llvm.org/D155543).
There are two things outstanding (why I didn't land before):
mlirOperationWalk
;run
works; the first version of the code looked like this:mlirOperationEqual(rootOp->get(), op)
for every op underneath the root op.
Supposing there's no desire for the slightly more efficient but highly convoluted approach, I can land this "posthaste".
But, since we have eyes on this now, any suggestions or approaches (or needs/concerns) are welcome.
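For readers unfamiliar with the two walk orders that `mlirOperationWalk` exposes, the following plain-Python sketch mimics the callback-plus-user-data shape on a toy tree. It is a hedged illustration only: `ToyOp`, `WalkOrder`, and `walk` are hypothetical names and do not call into MLIR.

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, List


class WalkOrder(Enum):
    PRE_ORDER = auto()
    POST_ORDER = auto()


@dataclass
class ToyOp:
    """Hypothetical op node; children stand in for nested operations."""

    name: str
    children: List["ToyOp"] = field(default_factory=list)


def walk(op: ToyOp, callback: Callable[[ToyOp, Any], None],
         user_data: Any, order: WalkOrder) -> None:
    """Visit op and all nested ops, passing user_data to every callback."""
    if order is WalkOrder.PRE_ORDER:
        callback(op, user_data)
    for child in op.children:
        walk(child, callback, user_data, order)
    if order is WalkOrder.POST_ORDER:
        callback(op, user_data)


def record(op: ToyOp, visited: list) -> None:
    """Callback that tunnels a list through user_data and appends to it."""
    visited.append(op.name)


module = ToyOp("module", [
    ToyOp("outer_const"),
    ToyOp("func", [ToyOp("inner_const"), ToyOp("return")]),
])

pre: list = []
post: list = []
walk(module, record, pre, WalkOrder.PRE_ORDER)
walk(module, record, post, WalkOrder.POST_ORDER)
```

In this sketch the post-order walk reaches the root last, so every nested op has been visited before the root is handed to the callback.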