Skip to content

Commit 77b9bbf

Browse files
committed
[mlir][python] Add walk method to PyOperationBase
This commit adds `walk` method that uses a python object as a callback
1 parent 6fd3677 commit 77b9bbf

File tree

3 files changed

+52
-1
lines changed

3 files changed

+52
-1
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,20 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
12491249
.str());
12501250
}
12511251

1252+
void PyOperationBase::walk(py::object callback, bool usePreOrder) {
1253+
PyOperation &operation = getOperation();
1254+
operation.checkValid();
1255+
MlirOperationWalkCallback walkCallback =
1256+
[](MlirOperation op,
1257+
void *userData) {
1258+
py::object *fn = static_cast<py::object *>(userData);
1259+
(*fn)(op);
1260+
};
1261+
mlirOperationWalk(operation, walkCallback, &callback,
1262+
usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
1263+
: MlirWalkOrder::MlirWalkPostOrder);
1264+
}
1265+
12521266
py::object PyOperationBase::getAsm(bool binary,
12531267
std::optional<int64_t> largeElementsLimit,
12541268
bool enableDebugInfo, bool prettyDebugInfo,
@@ -3038,7 +3052,9 @@ void mlir::python::populateIRCore(py::module &m) {
30383052
return operation.createOpView();
30393053
},
30403054
"Detaches the operation from its parent block.")
3041-
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
3055+
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3056+
.def("walk", &PyOperationBase::walk, py::arg("callback"),
3057+
py::arg("use_pre_order") = py::bool_(false));
30423058

30433059
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
30443060
.def_static("create", &PyOperation::create, py::arg("name"),

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,9 @@ class PyOperationBase {
579579
void writeBytecode(const pybind11::object &fileObject,
580580
std::optional<int64_t> bytecodeVersion);
581581

582+
// Implement the walk method.
583+
void walk(pybind11::object callback, bool usePreOrder);
584+
582585
/// Moves the operation before or after the other operation.
583586
void moveAfter(PyOperationBase &other);
584587
void moveBefore(PyOperationBase &other);

mlir/test/python/ir/operation.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,3 +1015,35 @@ def testOperationParse():
10151015
print(
10161016
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
10171017
)
1018+
1019+
# CHECK-LABEL: TEST: testOpWalk
1020+
@run
1021+
def testOpWalk():
1022+
ctx = Context()
1023+
ctx.allow_unregistered_dialects = True
1024+
module = Module.parse(
1025+
r"""
1026+
builtin.module {
1027+
func.func @f() {
1028+
func.return
1029+
}
1030+
}
1031+
""",
1032+
ctx,
1033+
)
1034+
callback = lambda op: print(op.name)
1035+
# Test post-order walk (default).
1036+
# CHECK-NEXT: Post-order
1037+
# CHECK-NEXT: func.return
1038+
# CHECK-NEXT: func.func
1039+
# CHECK-NEXT: builtin.module
1040+
print("Post-order")
1041+
module.operation.walk(callback)
1042+
1043+
# Test pre-order walk.
1044+
# CHECK-NEXT: Pre-order
1045+
# CHECK-NEXT: builtin.module
1046+
# CHECK-NEXT: func.fun
1047+
# CHECK-NEXT: func.return
1048+
print("Pre-order")
1049+
module.operation.walk(callback, True)

0 commit comments

Comments
 (0)