Skip to content

[mlir][python] Fix PyOperationBase::walk not catching exception in python callback #89225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 18, 2024

Conversation

tomnatan30
Copy link
Contributor

If the python callback throws an error, the c++ code will throw a py::error_already_set that needs to be caught and handled in the c++ code .

This change is inspired by the similar solution in PySymbolTable::walkSymbolTables.

@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-mlir

Author: None (tomnatan30)

Changes

If the python callback throws an error, the c++ code will throw a py::error_already_set that needs to be caught and handled in the c++ code .

This change is inspired by the similar solution in PySymbolTable::walkSymbolTables.


Full diff: https://github.com/llvm/llvm-project/pull/89225.diff

2 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+22-5)
  • (modified) mlir/test/python/ir/operation.py (+1-1)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d875f4eba2b139..0a12c53ac00abd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
     MlirWalkOrder walkOrder) {
   PyOperation &operation = getOperation();
   operation.checkValid();
+  struct UserData {
+    std::function<MlirWalkResult(MlirOperation)> callback;
+    bool gotException;
+    std::string exceptionWhat;
+    py::object exceptionType;
+  };
+  UserData userData{.callback = callback};
   MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                               void *userData) {
-    auto *fn =
-        static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
-    return (*fn)(op);
+    UserData *calleeUserData = static_cast<UserData *>(userData);
+    try {
+      return (calleeUserData->callback)(op);
+    } catch (py::error_already_set &e) {
+      calleeUserData->gotException = true;
+      calleeUserData->exceptionWhat = e.what();
+      calleeUserData->exceptionType = e.type();
+      return MlirWalkResult::MlirWalkResultInterrupt;
+    }
   };
-
-  mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+  if (userData.gotException) {
+    std::string message("Exception raised in callback: ");
+    message.append(userData.exceptionWhat);
+    throw std::runtime_error(message);
+  }
 }
 
 py::object PyOperationBase::getAsm(bool binary,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9666e63bda1e0e..3a5d850b86e3a2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1088,5 +1088,5 @@ def callback(op):
 
     try:
         module.operation.walk(callback)
-    except ValueError:
+    except RuntimeError:
         print("Exception raised")

std::string exceptionWhat;
py::object exceptionType;
};
UserData userData{.callback = callback};
Copy link
Member

@ftynse ftynse Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a c++20 feature, LLVM is c++17. Please update to be c++17 compatible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@tomnatan30 tomnatan30 force-pushed the piper_export_cl_625598483 branch from c3a6f7c to 8a2b373 Compare April 18, 2024 12:56
@makslevental
Copy link
Contributor

I suggest you land without a windows pass - windows CI is currently taking on the order of ~18 hours.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants