Skip to content

[mlir][python] Add walk method to PyOperationBase #87962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 17, 2024

Conversation

uenoku
Copy link
Member

@uenoku uenoku commented Apr 8, 2024

This commit adds walk method to PyOperationBase that uses a python object as a callback, e.g. op.walk(callback). Currently callback must return a walk result explicitly.

We(SiFive) have implemented walk method with python in our internal python tool for a while. However the overhead of python is expensive and it didn't scale well for large MLIR files. Just replacing walk with this version reduced the entire execution time of the tool by 30~40% and there are a few configs that the tool takes several hours to finish so this commit significantly improves tool performance.

This commit adds `walk` method that uses a python object as
a callback
@llvmbot llvmbot added the mlir label Apr 8, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 8, 2024

@llvm/pr-subscribers-mlir

Author: Hideto Ueno (uenoku)

Changes

This commit adds walk method t PyOperationBase that uses a python object as a callback, e.g. op.walk(lambda op: print(op)). The second optional argument is a boolean that specifies walk order.


Full diff: https://github.com/llvm/llvm-project/pull/87962.diff

3 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+17-1)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+3)
  • (modified) mlir/test/python/ir/operation.py (+32)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..848d918e16a7d1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1249,6 +1249,20 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
                               .str());
 }
 
+void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+  MlirOperationWalkCallback walkCallback =
+   [](MlirOperation op,
+                                              void *userData) {
+    py::object *fn = static_cast<py::object *>(userData);
+    (*fn)(op);
+  };
+  mlirOperationWalk(operation, walkCallback, &callback,
+                    usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
+                                : MlirWalkOrder::MlirWalkPostOrder);
+}
+
 py::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    bool enableDebugInfo, bool prettyDebugInfo,
@@ -3038,7 +3052,9 @@ void mlir::python::populateIRCore(py::module &m) {
             return operation.createOpView();
           },
           "Detaches the operation from its parent block.")
-      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+      .def("walk", &PyOperationBase::walk, py::arg("callback"),
+           py::arg("use_pre_order") = py::bool_(false));
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde25ae047..ed15dd4f87c2b4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,9 @@ class PyOperationBase {
   void writeBytecode(const pybind11::object &fileObject,
                      std::optional<int64_t> bytecodeVersion);
 
+  // Implement the walk method.
+  void walk(pybind11::object callback, bool usePreOrder);
+
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..92a4f1b1545c20 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,35 @@ def testOperationParse():
         print(
             f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
         )
+
+# CHECK-LABEL: TEST: testOpWalk
+@run
+def testOpWalk():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
+    builtin.module {
+      func.func @f() {
+        func.return
+      }
+    }
+  """,
+    ctx,
+    )
+    callback = lambda op: print(op.name)
+    # Test post-order walk (default).
+    # CHECK-NEXT:  Post-order
+    # CHECK-NEXT:  func.return
+    # CHECK-NEXT:  func.func
+    # CHECK-NEXT:  builtin.module
+    print("Post-order")
+    module.operation.walk(callback)
+
+    # Test pre-order walk.
+    # CHECK-NEXT:  Pre-order
+    # CHECK-NEXT:  builtin.module
+    # CHECK-NEXT:  func.fun
+    # CHECK-NEXT:  func.return
+    print("Pre-order")
+    module.operation.walk(callback, True)

@uenoku uenoku requested review from makslevental, ftynse and jpienaar and removed request for makslevental April 8, 2024 06:05
Copy link

github-actions bot commented Apr 8, 2024

✅ With the latest revision this PR passed the Python code formatter.

Copy link

github-actions bot commented Apr 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@ftynse
Copy link
Member

ftynse commented Apr 8, 2024

I wonder if this is something that should be implemented in python instead. We have access to op nesting structure of regions/blocks, so it shouldn't be hard. And not jumping between Python and C++ callbacks will simplify debugging for users.

Copy link
Member Author

@uenoku uenoku left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the feedback!

We have access to op nesting structure of regions/blocks, so it shouldn't be hard.

Yes exactly. We (SiFive) have implemented walk method in that way in our internal python tool for a while. However the overhead of python is expensive and it didn't scale well for large MLIR files witin SIFive. Just replacing walk with this version reduced the entire execution time of the tool by 30~40% and there are a few configs that the tool takes several hours to finish so we want to use the native walk method (though we should fix our tool algorithmically).

I agree that debugging will be harder so I'm fine if we cannot upstream this. In that case we'll probably implement it in CIRCT as a helper function.

@ftynse
Copy link
Member

ftynse commented Apr 8, 2024

Just replacing walk with this version reduced the entire execution time of the tool by 30~40%

This sounds fair enough as a justification. Please include it in the commit message.

@@ -1249,6 +1249,19 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
.str());
}

void PyOperationBase::walk(py::object callback, bool usePreOrder) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is possible to spell out the callback type as std::function<void (MlirOperation)> here, not sure if this will have any performance implications.

Consider exposing MlirWalkResult in C and Python API to allow functions to stop the walk early.

I'm also not a fan of using a boolean kwarg for order, but can live with it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for suggestion! I changed callback type to std::function<MlirWalkResult (MlirOperation)> as well as adding MlirWalkResult CAPI and python enums for MlirWalkResult/MlirWalkOrder.

# CHECK-NEXT: func.fun
# CHECK-NEXT: func.return
print("Pre-order")
module.operation.walk(callback, True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we test what happens if the callback raises an error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test for exception. It seems working as expected.

usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
: MlirWalkOrder::MlirWalkPostOrder);
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like the c++ walk to evolve to support two callbacks to have both pre/post order visitation in the same walk.

This isn't a blocker for your change, but I'll need to break the C API, so just a heads up :)
(I suspect we can preserve the Python API backward compatible on top of it)

@makslevental
Copy link
Contributor

Just commenting to say thanks for adding this - I've been meaning to add this extension for a while and it got away from me.

@uenoku uenoku requested a review from ftynse April 11, 2024 12:55
@@ -18,6 +18,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H

#include <pybind11/functional.h>
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed to implicitly convert python object to std::function.

@uenoku uenoku force-pushed the dev/hidetou/python-walk branch from 96f76ac to dabfd78 Compare April 11, 2024 13:08
void *userData) {
auto *fn =
static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
return (*fn)(op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: can you actually check the result here to see if it's py::none() (i.e., the callback doesn't have a return statement) and then default to MlirWalkResult::MlirWalkResultAdvance. I think that's reasonable semantics? @ftynse

Copy link
Member Author

@uenoku uenoku Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure it works since the function is type casted to std::function<MlirWalkResult(MlirOperation)>.
To allow void functions I think we could change the callback type to std::variant<std::function<MlirWalkResult(MlirOperation)>, std::function<void(MlirOperation)>> or maybe py::function (and dynamically type casts the returned type).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variant is overkill for sure but py::function might work? anyway it's just a thought rather than a hard request so nbd if you don't want to bother.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My two cents is if we can do this in a way that doesn't add much overhead it would be nice. But I'm also happy to see this go in with the user callback required to explicitly return a WalkResult... "explicit is better than implicit".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm too traumatized by the C++ version of omitting the return from a function and getting a random segfault, so I'm on the side of always having an explicit return. But not objecting strongly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm too traumatized by the C++ version of omitting the return from a function and getting a random segfault

I am similarly traumatized but in Python-land I think it's fairly well-known that a function with no return statement actually returns None.

@uenoku uenoku force-pushed the dev/hidetou/python-walk branch from dabfd78 to 9d1a282 Compare April 11, 2024 13:31
Copy link
Contributor

@mikeurbach mikeurbach left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have any major concern with the implementation, this looks good to me.

void *userData) {
auto *fn =
static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
return (*fn)(op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My two cents is if we can do this in a way that doesn't add much overhead it would be nice. But I'm also happy to see this go in with the user callback required to explicitly return a WalkResult... "explicit is better than implicit".

@uenoku uenoku requested a review from stellaraccident as a code owner April 16, 2024 13:18
@uenoku
Copy link
Member Author

uenoku commented Apr 16, 2024

Thank you for all comments! I think the current implementation that requires explicit return results from python callback would be good start so I'd like to just merge the PR once CI passed. Let me open a tracking issue for allowing void function in the callback.

@uenoku uenoku merged commit 4714883 into llvm:main Apr 17, 2024
@uenoku uenoku deleted the dev/hidetou/python-walk branch April 17, 2024 06:09
@metaflow
Copy link
Contributor

Hi, we see a libc++abi: terminating due to uncaught exception of type pybind11::error_already_set: ValueError: <EMPTY MESSAGE> in the mlir/test/python/ir/operation.py

@uenoku
Copy link
Member Author

uenoku commented Apr 24, 2024

Hi, we see a libc++abi: terminating due to uncaught exception of type pybind11::error_already_set: ValueError: in the mlir/test/python/ir/operation.py

@metaflow Sorry I missed the message. It looks that the issue was fixed by #89225. Thank you @tomnatan30!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants