@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
1255
1255
MlirWalkOrder walkOrder) {
1256
1256
PyOperation &operation = getOperation ();
1257
1257
operation.checkValid ();
1258
+ struct UserData {
1259
+ std::function<MlirWalkResult(MlirOperation)> callback;
1260
+ bool gotException;
1261
+ std::string exceptionWhat;
1262
+ py::object exceptionType;
1263
+ };
1264
+ UserData userData{.callback = callback};
1258
1265
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1259
1266
void *userData) {
1260
- auto *fn =
1261
- static_cast <std::function<MlirWalkResult (MlirOperation)> *>(userData);
1262
- return (*fn)(op);
1267
+ UserData *calleeUserData = static_cast <UserData *>(userData);
1268
+ try {
1269
+ return (calleeUserData->callback )(op);
1270
+ } catch (py::error_already_set &e) {
1271
+ calleeUserData->gotException = true ;
1272
+ calleeUserData->exceptionWhat = e.what ();
1273
+ calleeUserData->exceptionType = e.type ();
1274
+ return MlirWalkResult::MlirWalkResultInterrupt;
1275
+ }
1263
1276
};
1264
-
1265
- mlirOperationWalk (operation, walkCallback, &callback, walkOrder);
1277
+ mlirOperationWalk (operation, walkCallback, &userData, walkOrder);
1278
+ if (userData.gotException ) {
1279
+ std::string message (" Exception raised in callback: " );
1280
+ message.append (userData.exceptionWhat );
1281
+ throw std::runtime_error (message);
1282
+ }
1266
1283
}
1267
1284
1268
1285
py::object PyOperationBase::getAsm (bool binary,
0 commit comments