Skip to content

Commit 4714883

Browse files
authored
[mlir][python] Add walk method to PyOperationBase (#87962)
This commit adds `walk` method to PyOperationBase that uses a python object as a callback, e.g. `op.walk(callback)`. Currently callback must return a walk result explicitly. We(SiFive) have implemented walk method with python in our internal python tool for a while. However the overhead of python is expensive and it didn't scale well for large MLIR files. Just replacing walk with this version reduced the entire execution time of the tool by 30~40% and there are a few configs that the tool takes several hours to finish so this commit significantly improves tool performance.
1 parent b851c7f commit 4714883

File tree

7 files changed

+184
-17
lines changed

7 files changed

+184
-17
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
705705
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
706706
MlirOperation other);
707707

708+
/// Operation walk result.
709+
typedef enum MlirWalkResult {
710+
MlirWalkResultAdvance,
711+
MlirWalkResultInterrupt,
712+
MlirWalkResultSkip
713+
} MlirWalkResult;
714+
708715
/// Traversal order for operation walk.
709716
typedef enum MlirWalkOrder {
710717
MlirWalkPreOrder,
@@ -713,7 +720,8 @@ typedef enum MlirWalkOrder {
713720

714721
/// Operation walker type. The handler is passed an (opaque) reference to an
715722
/// operation and a pointer to a `userData`.
716-
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
723+
typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation,
724+
void *userData);
717725

718726
/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
719727
/// `*userData` is passed to the callback as well and can be used to tunnel some

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
1919
#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
2020

21+
#include <pybind11/functional.h>
2122
#include <pybind11/pybind11.h>
2223
#include <pybind11/pytypes.h>
2324
#include <pybind11/stl.h>

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
674674
data->rootOp.getOperation().getContext()->clearOperation(op);
675675
else
676676
data->rootSeen = true;
677+
return MlirWalkResult::MlirWalkResultAdvance;
677678
};
678679
mlirOperationWalk(op.getOperation(), invalidatingCallback,
679680
static_cast<void *>(&data), MlirWalkPreOrder);
@@ -1249,6 +1250,21 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
12491250
.str());
12501251
}
12511252

1253+
void PyOperationBase::walk(
1254+
std::function<MlirWalkResult(MlirOperation)> callback,
1255+
MlirWalkOrder walkOrder) {
1256+
PyOperation &operation = getOperation();
1257+
operation.checkValid();
1258+
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1259+
void *userData) {
1260+
auto *fn =
1261+
static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
1262+
return (*fn)(op);
1263+
};
1264+
1265+
mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
1266+
}
1267+
12521268
py::object PyOperationBase::getAsm(bool binary,
12531269
std::optional<int64_t> largeElementsLimit,
12541270
bool enableDebugInfo, bool prettyDebugInfo,
@@ -2511,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) {
25112527
.value("NOTE", MlirDiagnosticNote)
25122528
.value("REMARK", MlirDiagnosticRemark);
25132529

2530+
py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
2531+
.value("PRE_ORDER", MlirWalkPreOrder)
2532+
.value("POST_ORDER", MlirWalkPostOrder);
2533+
2534+
py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
2535+
.value("ADVANCE", MlirWalkResultAdvance)
2536+
.value("INTERRUPT", MlirWalkResultInterrupt)
2537+
.value("SKIP", MlirWalkResultSkip);
2538+
25142539
//----------------------------------------------------------------------------
25152540
// Mapping of Diagnostics.
25162541
//----------------------------------------------------------------------------
@@ -2989,8 +3014,7 @@ void mlir::python::populateIRCore(py::module &m) {
29893014
py::arg("binary") = false, kOperationPrintStateDocstring)
29903015
.def("print",
29913016
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
2992-
bool, py::object, bool>(
2993-
&PyOperationBase::print),
3017+
bool, py::object, bool>(&PyOperationBase::print),
29943018
// Careful: Lots of arguments must match up with print method.
29953019
py::arg("large_elements_limit") = py::none(),
29963020
py::arg("enable_debug_info") = false,
@@ -3038,7 +3062,9 @@ void mlir::python::populateIRCore(py::module &m) {
30383062
return operation.createOpView();
30393063
},
30403064
"Detaches the operation from its parent block.")
3041-
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
3065+
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3066+
.def("walk", &PyOperationBase::walk, py::arg("callback"),
3067+
py::arg("walk_order") = MlirWalkPostOrder);
30423068

30433069
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
30443070
.def_static("create", &PyOperation::create, py::arg("name"),

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,10 @@ class PyOperationBase {
579579
void writeBytecode(const pybind11::object &fileObject,
580580
std::optional<int64_t> bytecodeVersion);
581581

582+
// Implement the walk method.
583+
void walk(std::function<MlirWalkResult(MlirOperation)> callback,
584+
MlirWalkOrder walkOrder);
585+
582586
/// Moves the operation before or after the other operation.
583587
void moveAfter(PyOperationBase &other);
584588
void moveBefore(PyOperationBase &other);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
717717
return unwrap(op)->moveBefore(unwrap(other));
718718
}
719719

720+
static mlir::WalkResult unwrap(MlirWalkResult result) {
721+
switch (result) {
722+
case MlirWalkResultAdvance:
723+
return mlir::WalkResult::advance();
724+
725+
case MlirWalkResultInterrupt:
726+
return mlir::WalkResult::interrupt();
727+
728+
case MlirWalkResultSkip:
729+
return mlir::WalkResult::skip();
730+
}
731+
}
732+
720733
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
721734
void *userData, MlirWalkOrder walkOrder) {
722735
switch (walkOrder) {
723736

724737
case MlirWalkPreOrder:
725738
unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
726-
[callback, userData](Operation *op) { callback(wrap(op), userData); });
739+
[callback, userData](Operation *op) {
740+
return unwrap(callback(wrap(op), userData));
741+
});
727742
break;
728743
case MlirWalkPostOrder:
729744
unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
730-
[callback, userData](Operation *op) { callback(wrap(op), userData); });
745+
[callback, userData](Operation *op) {
746+
return unwrap(callback(wrap(op), userData));
747+
});
731748
}
732749
}
733750

mlir/test/CAPI/ir.c

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,9 +2244,22 @@ typedef struct {
22442244
const char *x;
22452245
} callBackData;
22462246

2247-
void walkCallBack(MlirOperation op, void *rootOpVoid) {
2247+
MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
22482248
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
22492249
mlirIdentifierStr(mlirOperationGetName(op)).data);
2250+
return MlirWalkResultAdvance;
2251+
}
2252+
2253+
MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
2254+
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
2255+
mlirIdentifierStr(mlirOperationGetName(op)).data);
2256+
if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "func.func") ==
2257+
0)
2258+
return MlirWalkResultSkip;
2259+
if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "arith.addi") ==
2260+
0)
2261+
return MlirWalkResultInterrupt;
2262+
return MlirWalkResultAdvance;
22502263
}
22512264

22522265
int testOperationWalk(MlirContext ctx) {
@@ -2259,29 +2272,52 @@ int testOperationWalk(MlirContext ctx) {
22592272
" arith.addi %1, %1: i32\n"
22602273
" return\n"
22612274
" }\n"
2275+
" func.func @bar() {\n"
2276+
" return\n"
2277+
" }\n"
22622278
"}";
22632279
MlirModule module =
22642280
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
22652281

22662282
callBackData data;
22672283
data.x = "i love you";
22682284

2269-
// CHECK: i love you: arith.constant
2270-
// CHECK: i love you: arith.addi
2271-
// CHECK: i love you: func.return
2272-
// CHECK: i love you: func.func
2273-
// CHECK: i love you: builtin.module
2285+
// CHECK-NEXT: i love you: arith.constant
2286+
// CHECK-NEXT: i love you: arith.addi
2287+
// CHECK-NEXT: i love you: func.return
2288+
// CHECK-NEXT: i love you: func.func
2289+
// CHECK-NEXT: i love you: func.return
2290+
// CHECK-NEXT: i love you: func.func
2291+
// CHECK-NEXT: i love you: builtin.module
22742292
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
22752293
(void *)(&data), MlirWalkPostOrder);
22762294

22772295
data.x = "i don't love you";
2278-
// CHECK: i don't love you: builtin.module
2279-
// CHECK: i don't love you: func.func
2280-
// CHECK: i don't love you: arith.constant
2281-
// CHECK: i don't love you: arith.addi
2282-
// CHECK: i don't love you: func.return
2296+
// CHECK-NEXT: i don't love you: builtin.module
2297+
// CHECK-NEXT: i don't love you: func.func
2298+
// CHECK-NEXT: i don't love you: arith.constant
2299+
// CHECK-NEXT: i don't love you: arith.addi
2300+
// CHECK-NEXT: i don't love you: func.return
2301+
// CHECK-NEXT: i don't love you: func.func
2302+
// CHECK-NEXT: i don't love you: func.return
22832303
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
22842304
(void *)(&data), MlirWalkPreOrder);
2305+
2306+
data.x = "interrupt";
2307+
// Interrupted at `arith.addi`
2308+
// CHECK-NEXT: interrupt: arith.constant
2309+
// CHECK-NEXT: interrupt: arith.addi
2310+
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
2311+
(void *)(&data), MlirWalkPostOrder);
2312+
2313+
data.x = "skip";
2314+
// Skip at `func.func`
2315+
// CHECK-NEXT: skip: builtin.module
2316+
// CHECK-NEXT: skip: func.func
2317+
// CHECK-NEXT: skip: func.func
2318+
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
2319+
(void *)(&data), MlirWalkPreOrder);
2320+
22852321
mlirModuleDestroy(module);
22862322
return 0;
22872323
}

mlir/test/python/ir/operation.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,3 +1015,78 @@ def testOperationParse():
10151015
print(
10161016
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
10171017
)
1018+
1019+
1020+
# CHECK-LABEL: TEST: testOpWalk
1021+
@run
1022+
def testOpWalk():
1023+
ctx = Context()
1024+
ctx.allow_unregistered_dialects = True
1025+
module = Module.parse(
1026+
r"""
1027+
builtin.module {
1028+
func.func @f() {
1029+
func.return
1030+
}
1031+
}
1032+
""",
1033+
ctx,
1034+
)
1035+
1036+
def callback(op):
1037+
print(op.name)
1038+
return WalkResult.ADVANCE
1039+
1040+
# Test post-order walk (default).
1041+
# CHECK-NEXT: Post-order
1042+
# CHECK-NEXT: func.return
1043+
# CHECK-NEXT: func.func
1044+
# CHECK-NEXT: builtin.module
1045+
print("Post-order")
1046+
module.operation.walk(callback)
1047+
1048+
# Test pre-order walk.
1049+
# CHECK-NEXT: Pre-order
1050+
# CHECK-NEXT: builtin.module
1051+
# CHECK-NEXT: func.fun
1052+
# CHECK-NEXT: func.return
1053+
print("Pre-order")
1054+
module.operation.walk(callback, WalkOrder.PRE_ORDER)
1055+
1056+
# Test interrput.
1057+
# CHECK-NEXT: Interrupt post-order
1058+
# CHECK-NEXT: func.return
1059+
print("Interrupt post-order")
1060+
1061+
def callback(op):
1062+
print(op.name)
1063+
return WalkResult.INTERRUPT
1064+
1065+
module.operation.walk(callback)
1066+
1067+
# Test skip.
1068+
# CHECK-NEXT: Skip pre-order
1069+
# CHECK-NEXT: builtin.module
1070+
print("Skip pre-order")
1071+
1072+
def callback(op):
1073+
print(op.name)
1074+
return WalkResult.SKIP
1075+
1076+
module.operation.walk(callback, WalkOrder.PRE_ORDER)
1077+
1078+
# Test exception.
1079+
# CHECK: Exception
1080+
# CHECK-NEXT: func.return
1081+
# CHECK-NEXT: Exception raised
1082+
print("Exception")
1083+
1084+
def callback(op):
1085+
print(op.name)
1086+
raise ValueError
1087+
return WalkResult.ADVANCE
1088+
1089+
try:
1090+
module.operation.walk(callback)
1091+
except ValueError:
1092+
print("Exception raised")

0 commit comments

Comments
 (0)