Skip to content

Commit bc55364

Browse files
authored
[mlir][python] Fix PyOperationBase::walk not catching exception in python callback (#89225)
If the python callback throws an error, the c++ code will throw a py::error_already_set that needs to be caught and handled in the c++ code . This change is inspired by the similar solution in PySymbolTable::walkSymbolTables.
1 parent 8f07a67 commit bc55364

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
12551255
MlirWalkOrder walkOrder) {
12561256
PyOperation &operation = getOperation();
12571257
operation.checkValid();
1258+
struct UserData {
1259+
std::function<MlirWalkResult(MlirOperation)> callback;
1260+
bool gotException;
1261+
std::string exceptionWhat;
1262+
py::object exceptionType;
1263+
};
1264+
UserData userData{callback, false, {}, {}};
12581265
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
12591266
void *userData) {
1260-
auto *fn =
1261-
static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
1262-
return (*fn)(op);
1267+
UserData *calleeUserData = static_cast<UserData *>(userData);
1268+
try {
1269+
return (calleeUserData->callback)(op);
1270+
} catch (py::error_already_set &e) {
1271+
calleeUserData->gotException = true;
1272+
calleeUserData->exceptionWhat = e.what();
1273+
calleeUserData->exceptionType = e.type();
1274+
return MlirWalkResult::MlirWalkResultInterrupt;
1275+
}
12631276
};
1264-
1265-
mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
1277+
mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1278+
if (userData.gotException) {
1279+
std::string message("Exception raised in callback: ");
1280+
message.append(userData.exceptionWhat);
1281+
throw std::runtime_error(message);
1282+
}
12661283
}
12671284

12681285
py::object PyOperationBase::getAsm(bool binary,

mlir/test/python/ir/operation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,5 +1088,5 @@ def callback(op):
10881088

10891089
try:
10901090
module.operation.walk(callback)
1091-
except ValueError:
1091+
except RuntimeError:
10921092
print("Exception raised")

0 commit comments

Comments
 (0)