Skip to content

[MLIR][python bindings] invalidate ops after PassManager run #69746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 21, 2023

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Oct 20, 2023

Fixes #69730 (also see https://reviews.llvm.org/D155543).

There are two things outstanding (why I didn't land before):

  1. add some C API tests for mlirOperationWalk;
  2. potentially refactor how the invalidation in run works; the first version of the code looked like this:
    if (invalidateOps) {
      auto *context = op.getOperation().getContext().get();
      MlirOperationWalkCallback invalidatingCallback =
          [](MlirOperation op, void *userData) {
            PyMlirContext *context =
                static_cast<PyMlirContext *>(userData);
            context->setOperationInvalid(op);
          };
      auto numRegions =
          mlirOperationGetNumRegions(op.getOperation().get());
      for (int i = 0; i < numRegions; ++i) {
        MlirRegion region =
            mlirOperationGetRegion(op.getOperation().get(), i);
        for (MlirBlock block = mlirRegionGetFirstBlock(region);
             !mlirBlockIsNull(block);
             block = mlirBlockGetNextInRegion(block))
          for (MlirOperation childOp =
                   mlirBlockGetFirstOperation(block);
               !mlirOperationIsNull(childOp);
               childOp = mlirOperationGetNextInBlock(childOp))
            mlirOperationWalk(childOp, invalidatingCallback, context,
                              MlirWalkPostOrder);
      }
    }
    This is verbose and ugly but it has the important benefit of not executing mlirOperationEqual(rootOp->get(), op) for every op underneath the root op.

Supposing there's no desire for the slightly more efficient but highly convoluted approach, I can land this "posthaste".
But, since we have eyes on this now, any suggestions or approaches (or needs/concerns) are welcome.

@makslevental makslevental force-pushed the invalidate_ops branch 2 times, most recently from b0ca636 to d214b4f Compare October 20, 2023 18:38
@makslevental makslevental marked this pull request as ready for review October 20, 2023 18:46
@llvmbot llvmbot added the mlir label Oct 20, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 20, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Fixes #69730 (also see https://reviews.llvm.org/D155543).

There are two things outstanding (why I didn't land before):

  1. add some C API tests for mlirOperationWalk;
  2. potentially refactor how the invalidation in run works; the first version of the code looked like this:
    if (invalidateOps) {
      auto *context = op.getOperation().getContext().get();
      MlirOperationWalkCallback invalidatingCallback =
          [](MlirOperation op, void *userData) {
            PyMlirContext *context =
                static_cast&lt;PyMlirContext *&gt;(userData);
            context-&gt;setOperationInvalid(op);
          };
      auto numRegions =
          mlirOperationGetNumRegions(op.getOperation().get());
      for (int i = 0; i &lt; numRegions; ++i) {
        MlirRegion region =
            mlirOperationGetRegion(op.getOperation().get(), i);
        for (MlirBlock block = mlirRegionGetFirstBlock(region);
             !mlirBlockIsNull(block);
             block = mlirBlockGetNextInRegion(block))
          for (MlirOperation childOp =
                   mlirBlockGetFirstOperation(block);
               !mlirOperationIsNull(childOp);
               childOp = mlirOperationGetNextInBlock(childOp))
            mlirOperationWalk(childOp, invalidatingCallback, context,
                              MlirWalkPostOrder);
      }
    }
    This is verbose and ugly but it has the important benefit of not mlirOperationEqual(rootOp-&gt;get(), op) for every op underneath the root op.

Supposing there's no desire for the slightly more efficient but highly convoluted approach, I can land this "posthaste".
But, since we have eyes on this now, any suggestions or approaches (or needs/concerns) are welcome.


Full diff: https://github.com/llvm/llvm-project/pull/69746.diff

6 Files Affected:

  • (modified) mlir/include/mlir-c/IR.h (+15)
  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+5)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+5)
  • (modified) mlir/lib/Bindings/Python/Pass.cpp (+25-7)
  • (modified) mlir/lib/CAPI/IR/IR.cpp (+15)
  • (modified) mlir/test/python/pass_manager.py (+98-1)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e361f33a0d83641..3163c3cc40c58b1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -698,6 +698,21 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
 /// ownership is transferred to the block of the other operation.
 MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
                                                 MlirOperation other);
+
+typedef enum MlirWalkOrder {
+  MlirWalkPreOrder,
+  MlirWalkPostOrder
+} MlirWalkOrder;
+
+typedef void (*MlirOperationWalkCallback)(MlirOperation, void *);
+
+/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
+/// `*userData` is passed to the callback as well and can be used to tunnel some
+/// some context or other data into the callback.
+MLIR_CAPI_EXPORTED
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 389a4621c14e594..a8ea1a381edb96e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,6 +635,11 @@ size_t PyMlirContext::clearLiveOperations() {
   return numInvalidated;
 }
 
+void PyMlirContext::setOperationInvalid(MlirOperation op) {
+  if (liveOperations.contains(op.ptr))
+    liveOperations[op.ptr].second->setInvalid();
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index c5412e735dddcb5..26292885711a4e4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,6 +209,11 @@ class PyMlirContext {
   /// place.
   size_t clearLiveOperations();
 
+  /// Sets an operation invalid. This is useful for when some non-bindings
+  /// code destroys the operation and the bindings need to made aware. For
+  /// example, in the case when pass manager is run.
+  void setOperationInvalid(MlirOperation op);
+
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
   size_t getLiveModuleCount();
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdbfcfbc22957a6..5e7a918d1bad55b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -13,6 +13,7 @@
 #include "mlir-c/Pass.h"
 
 namespace py = pybind11;
+using namespace py::literals;
 using namespace mlir;
 using namespace mlir::python;
 
@@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
                  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
              return new PyPassManager(passManager);
            }),
-           py::arg("anchor_op") = py::str("any"),
-           py::arg("context") = py::none(),
+           "anchor_op"_a = py::str("any"), "context"_a = py::none(),
            "Create a new PassManager for the current (or provided) Context.")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyPassManager::getCapsule)
@@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           [](PyPassManager &passManager, bool enable) {
             mlirPassManagerEnableVerifier(passManager.get(), enable);
           },
-          py::arg("enable"), "Enable / disable verify-each.")
+          "enable"_a, "Enable / disable verify-each.")
       .def_static(
           "parse",
           [](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
               throw py::value_error(std::string(errorMsg.join()));
             return new PyPassManager(passManager);
           },
-          py::arg("pipeline"), py::arg("context") = py::none(),
+          "pipeline"_a, "context"_a = py::none(),
           "Parse a textual pass-pipeline and return a top-level PassManager "
           "that can be applied on a Module. Throw a ValueError if the pipeline "
           "can't be parsed")
@@ -111,12 +111,30 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
             if (mlirLogicalResultIsFailure(status))
               throw py::value_error(std::string(errorMsg.join()));
           },
-          py::arg("pipeline"),
+          "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
-          [](PyPassManager &passManager, PyOperationBase &op) {
+          [](PyPassManager &passManager, PyOperationBase &op,
+             bool invalidateOps) {
+            if (invalidateOps) {
+              // Mark all ops below the op that the passmanager will be rooted
+              // at as invalid.
+              MlirOperationWalkCallback invalidatingCallback =
+                  [](MlirOperation op, void *rootOpVoid) {
+                    PyOperation *rootOp =
+                        static_cast<PyOperation *>(rootOpVoid);
+                    if (!mlirOperationEqual(rootOp->get(), op)) {
+                      rootOp->getOperation().getContext()->setOperationInvalid(
+                          op);
+                    }
+                  };
+              mlirOperationWalk(op.getOperation(), invalidatingCallback,
+                                static_cast<void *>(&op.getOperation()),
+                                MlirWalkPostOrder);
+            }
+            // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
                 passManager.get(), op.getOperation().get());
@@ -124,7 +142,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
               throw MLIRError("Failure while executing pass pipeline",
                               errors.take());
           },
-          py::arg("operation"),
+          "operation"_a, "invalidate_ops"_a = true,
           "Run the pass manager on the provided operation, raising an "
           "MLIRError on failure.")
       .def(
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c1abbbe364611af..0a5151751873f2b 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
 
@@ -705,6 +706,20 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
   return unwrap(op)->moveBefore(unwrap(other));
 }
 
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder) {
+  switch (walkOrder) {
+
+  case MlirWalkPreOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
+        [callback, userData](Operation *op) { callback(wrap(op), userData); });
+    break;
+  case MlirWalkPostOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
+        [callback, userData](Operation *op) { callback(wrap(op), userData); });
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 4b3a02ac42bd9b1..e7f79ddc75113e0 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -4,6 +4,8 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
+from mlir.dialects.builtin import ModuleOp
+
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@ def testCapsule():
 
 run(testCapsule)
 
+
 # CHECK-LABEL: TEST: testConstruct
 @run
 def testConstruct():
@@ -68,6 +71,7 @@ def testParseSuccess():
 
 run(testParseSuccess)
 
+
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSpacedPipeline
 def testParseSpacedPipeline():
@@ -84,6 +88,7 @@ def testParseSpacedPipeline():
 
 run(testParseSpacedPipeline)
 
+
 # Verify failure on unregistered pass.
 # CHECK-LABEL: TEST: testParseFail
 def testParseFail():
@@ -102,6 +107,7 @@ def testParseFail():
 
 run(testParseFail)
 
+
 # Check that adding to a pass manager works
 # CHECK-LABEL: TEST: testAdd
 @run
@@ -147,6 +153,7 @@ def testRunPipeline():
 # CHECK: func.return        , 1
 run(testRunPipeline)
 
+
 # CHECK-LABEL: TEST: testRunPipelineError
 @run
 def testRunPipelineError():
@@ -162,4 +169,94 @@ def testRunPipelineError():
             # CHECK:   error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
             # CHECK:    note: "-":1:1: see current operation: "test.op"() : () -> ()
             # CHECK: >
-            print(f"Exception: <{e}>")
+            log(f"Exception: <{e}>")
+
+
+# CHECK-LABEL: TEST: testPostPassOpInvalidation
+@run
+def testPostPassOpInvalidation():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            arith.constant 10
+            func.func @foo() {
+              arith.constant 10
+              return
+            }
+          }
+        """
+        )
+
+        # CHECK: invalidate_ops=False
+        log("invalidate_ops=False")
+
+        outer_const_op = module.body.operations[0]
+        # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
+        log(outer_const_op)
+
+        func_op = module.body.operations[1]
+        # CHECK: func.func @[[FOO:.*]]() {
+        # CHECK:   %[[VAL1:.*]] = arith.constant 10 : i64
+        # CHECK:   return
+        # CHECK: }
+        log(func_op)
+
+        inner_const_op = func_op.body.blocks[0].operations[0]
+        # CHECK: %[[VAL1]] = arith.constant 10 : i64
+        log(inner_const_op)
+
+        PassManager.parse("builtin.module(canonicalize)").run(
+            module, invalidate_ops=False
+        )
+        # CHECK: func.func @foo() {
+        # CHECK:   return
+        # CHECK: }
+        log(func_op)
+
+        # CHECK: func.func @foo() {
+        # CHECK:   return
+        # CHECK: }
+        log(module)
+
+        # CHECK: invalidate_ops=True
+        log("invalidate_ops=True")
+
+        module = ModuleOp.parse(
+            """
+          module {
+            arith.constant 10
+            func.func @foo() {
+              arith.constant 10
+              return
+            }
+          }
+        """
+        )
+        outer_const_op = module.body.operations[0]
+        func_op = module.body.operations[1]
+        inner_const_op = func_op.body.blocks[0].operations[0]
+
+        PassManager.parse("builtin.module(canonicalize)").run(module)
+        try:
+            log(func_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+
+        try:
+            log(outer_const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+
+        try:
+            log(inner_const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+
+        # CHECK: func.func @foo() {
+        # CHECK:   return
+        # CHECK: }
+        log(module)

@makslevental makslevental force-pushed the invalidate_ops branch 2 times, most recently from ae8f1c8 to 0636488 Compare October 20, 2023 19:20
Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Stupid suggestion: if we could "revalidate" ops, we can invalidate everything including the root and than revalidate the root back.

@makslevental
Copy link
Contributor Author

Stupid suggestion: if we could "revalidate" ops, we can invalidate everything including the root and than revalidate the root back.

image

but if "validation" becomes reversible then it might be wise to put LiveOperationMap liveOperations behind a mutex? Just so someone doesn't accidentally create a race condition at some point down the line?

One alternative is to make the walk MlirWalkPreOrder, do nothing but set a flag like rootSeen = true, and from then on check that flag instead of performing mlirOperationEqual. I should've done that in the first place I think...

@makslevental makslevental force-pushed the invalidate_ops branch 3 times, most recently from 27dffe2 to 8b5c99c Compare October 20, 2023 22:51
@makslevental makslevental merged commit bdc3e6c into llvm:main Oct 21, 2023
@makslevental makslevental deleted the invalidate_ops branch October 21, 2023 01:28
@jpienaar
Copy link
Member

So invalidating is default on and only in cases where (expert user) knows it's safe can they opt out? (Reading on mobile so high chance of misreading :-))

@makslevental
Copy link
Contributor Author

So invalidating is default on and only in cases where (expert user) knows it's safe can they opt out? (Reading on mobile so high chance of misreading :-))

Yup "fail-dangerous" only if you're a professional 🙂.

ingomueller-net added a commit to ingomueller-net/llvm-project that referenced this pull request Oct 24, 2023
`PyOperations` are Python-level handles to `Operation *` instances. When
the latter are modified by C++, the former need to be invalidated.
 llvm#69746 implements such invalidation mechanism by setting all
`PyReferences` to `invalid`. However, that is not enough: they also need
to be removed from the `liveOperations` map since other parts of the
code (such as `PyOperation::createDetached`) assume that that map only
contains valid refs.

This is required to actually solve the issue in llvm#69730.
ingomueller-net added a commit to ingomueller-net/llvm-project that referenced this pull request Oct 24, 2023
ingomueller-net added a commit that referenced this pull request Oct 25, 2023
`PyOperations` are Python-level handles to `Operation *` instances. When
the latter are modified by C++, the former need to be invalidated.
#69746 implements such invalidation mechanism by setting all
`PyReferences` to `invalid`. However, that is not enough: they also need
to be removed from the `liveOperations` map since other parts of the
code (such as `PyOperation::createDetached`) assume that that map only
contains valid refs.

This is required to actually solve the issue in #69730.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][python] Tracking of live operations fails.
4 participants