Skip to content

Commit 9d1a282

Browse files
committed
Add CAPI for WalkResult. Add python enum definition of WalkResult/WalkOrder
1 parent 218279d commit 9d1a282

File tree

7 files changed

+136
-24
lines changed

7 files changed

+136
-24
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
705705
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
706706
MlirOperation other);
707707

708+
/// Operation walk result.
709+
typedef enum MlirWalkResult {
710+
MlirWalkResultAdvance,
711+
MlirWalkResultInterrupt,
712+
MlirWalkResultSkip
713+
} MlirWalkResult;
714+
708715
/// Traversal order for operation walk.
709716
typedef enum MlirWalkOrder {
710717
MlirWalkPreOrder,
@@ -713,7 +720,8 @@ typedef enum MlirWalkOrder {
713720

714721
/// Operation walker type. The handler is passed an (opaque) reference to an
715722
/// operation and a pointer to a `userData`.
716-
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
723+
typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation,
724+
void *userData);
717725

718726
/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
719727
/// `*userData` is passed to the callback as well and can be used to tunnel some

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
1919
#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
2020

21+
#include <pybind11/functional.h>
2122
#include <pybind11/pybind11.h>
2223
#include <pybind11/pytypes.h>
2324
#include <pybind11/stl.h>

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
674674
data->rootOp.getOperation().getContext()->clearOperation(op);
675675
else
676676
data->rootSeen = true;
677+
return MlirWalkResult::MlirWalkResultAdvance;
677678
};
678679
mlirOperationWalk(op.getOperation(), invalidatingCallback,
679680
static_cast<void *>(&data), MlirWalkPreOrder);
@@ -1249,17 +1250,19 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
12491250
.str());
12501251
}
12511252

1252-
void PyOperationBase::walk(py::object callback, bool usePreOrder) {
1253+
void PyOperationBase::walk(
1254+
std::function<MlirWalkResult(MlirOperation)> callback,
1255+
MlirWalkOrder walkOrder) {
12531256
PyOperation &operation = getOperation();
12541257
operation.checkValid();
12551258
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
12561259
void *userData) {
1257-
py::object *fn = static_cast<py::object *>(userData);
1258-
(*fn)(op);
1260+
auto *fn =
1261+
static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
1262+
return (*fn)(op);
12591263
};
1260-
mlirOperationWalk(operation, walkCallback, &callback,
1261-
usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
1262-
: MlirWalkOrder::MlirWalkPostOrder);
1264+
1265+
mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
12631266
}
12641267

12651268
py::object PyOperationBase::getAsm(bool binary,
@@ -2524,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) {
25242527
.value("NOTE", MlirDiagnosticNote)
25252528
.value("REMARK", MlirDiagnosticRemark);
25262529

2530+
py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
2531+
.value("PRE_ORDER", MlirWalkPreOrder)
2532+
.value("POST_ORDER", MlirWalkPostOrder);
2533+
2534+
py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
2535+
.value("ADVANCE", MlirWalkResultAdvance)
2536+
.value("INTERRUPT", MlirWalkResultInterrupt)
2537+
.value("SKIP", MlirWalkResultSkip);
2538+
25272539
//----------------------------------------------------------------------------
25282540
// Mapping of Diagnostics.
25292541
//----------------------------------------------------------------------------
@@ -3052,7 +3064,7 @@ void mlir::python::populateIRCore(py::module &m) {
30523064
"Detaches the operation from its parent block.")
30533065
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
30543066
.def("walk", &PyOperationBase::walk, py::arg("callback"),
3055-
py::arg("use_pre_order") = py::bool_(false));
3067+
py::arg("walk_order") = MlirWalkPostOrder);
30563068

30573069
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
30583070
.def_static("create", &PyOperation::create, py::arg("name"),

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,8 @@ class PyOperationBase {
580580
std::optional<int64_t> bytecodeVersion);
581581

582582
// Implement the walk method.
583-
void walk(pybind11::object callback, bool usePreOrder);
583+
void walk(std::function<MlirWalkResult(MlirOperation)> callback,
584+
MlirWalkOrder walkOrder);
584585

585586
/// Moves the operation before or after the other operation.
586587
void moveAfter(PyOperationBase &other);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
717717
return unwrap(op)->moveBefore(unwrap(other));
718718
}
719719

720+
static mlir::WalkResult translateWalkResult(MlirWalkResult result) {
721+
switch (result) {
722+
case MlirWalkResultAdvance:
723+
return mlir::WalkResult::advance();
724+
725+
case MlirWalkResultInterrupt:
726+
return mlir::WalkResult::interrupt();
727+
728+
case MlirWalkResultSkip:
729+
return mlir::WalkResult::skip();
730+
}
731+
}
732+
720733
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
721734
void *userData, MlirWalkOrder walkOrder) {
722735
switch (walkOrder) {
723736

724737
case MlirWalkPreOrder:
725738
unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
726-
[callback, userData](Operation *op) { callback(wrap(op), userData); });
739+
[callback, userData](Operation *op) {
740+
return translateWalkResult(callback(wrap(op), userData));
741+
});
727742
break;
728743
case MlirWalkPostOrder:
729744
unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
730-
[callback, userData](Operation *op) { callback(wrap(op), userData); });
745+
[callback, userData](Operation *op) {
746+
return translateWalkResult(callback(wrap(op), userData));
747+
});
731748
}
732749
}
733750

mlir/test/CAPI/ir.c

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,9 +2244,22 @@ typedef struct {
22442244
const char *x;
22452245
} callBackData;
22462246

2247-
void walkCallBack(MlirOperation op, void *rootOpVoid) {
2247+
MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
22482248
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
22492249
mlirIdentifierStr(mlirOperationGetName(op)).data);
2250+
return MlirWalkResultAdvance;
2251+
}
2252+
2253+
MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
2254+
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
2255+
mlirIdentifierStr(mlirOperationGetName(op)).data);
2256+
if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "func.func") ==
2257+
0)
2258+
return MlirWalkResultSkip;
2259+
if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "arith.addi") ==
2260+
0)
2261+
return MlirWalkResultInterrupt;
2262+
return MlirWalkResultAdvance;
22502263
}
22512264

22522265
int testOperationWalk(MlirContext ctx) {
@@ -2259,29 +2272,52 @@ int testOperationWalk(MlirContext ctx) {
22592272
" arith.addi %1, %1: i32\n"
22602273
" return\n"
22612274
" }\n"
2275+
" func.func @bar() {\n"
2276+
" return\n"
2277+
" }\n"
22622278
"}";
22632279
MlirModule module =
22642280
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
22652281

22662282
callBackData data;
22672283
data.x = "i love you";
22682284

2269-
// CHECK: i love you: arith.constant
2270-
// CHECK: i love you: arith.addi
2271-
// CHECK: i love you: func.return
2272-
// CHECK: i love you: func.func
2273-
// CHECK: i love you: builtin.module
2285+
// CHECK-NEXT: i love you: arith.constant
2286+
// CHECK-NEXT: i love you: arith.addi
2287+
// CHECK-NEXT: i love you: func.return
2288+
// CHECK-NEXT: i love you: func.func
2289+
// CHECK-NEXT: i love you: func.return
2290+
// CHECK-NEXT: i love you: func.func
2291+
// CHECK-NEXT: i love you: builtin.module
22742292
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
22752293
(void *)(&data), MlirWalkPostOrder);
22762294

22772295
data.x = "i don't love you";
2278-
// CHECK: i don't love you: builtin.module
2279-
// CHECK: i don't love you: func.func
2280-
// CHECK: i don't love you: arith.constant
2281-
// CHECK: i don't love you: arith.addi
2282-
// CHECK: i don't love you: func.return
2296+
// CHECK-NEXT: i don't love you: builtin.module
2297+
// CHECK-NEXT: i don't love you: func.func
2298+
// CHECK-NEXT: i don't love you: arith.constant
2299+
// CHECK-NEXT: i don't love you: arith.addi
2300+
// CHECK-NEXT: i don't love you: func.return
2301+
// CHECK-NEXT: i don't love you: func.func
2302+
// CHECK-NEXT: i don't love you: func.return
22832303
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
22842304
(void *)(&data), MlirWalkPreOrder);
2305+
2306+
data.x = "interrupt";
2307+
// Interrupted at `arith.addi`
2308+
// CHECK-NEXT: interrupt: arith.constant
2309+
// CHECK-NEXT: interrupt: arith.addi
2310+
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
2311+
(void *)(&data), MlirWalkPostOrder);
2312+
2313+
data.x = "skip";
2314+
// Skip at `func.func`
2315+
// CHECK-NEXT: skip: builtin.module
2316+
// CHECK-NEXT: skip: func.func
2317+
// CHECK-NEXT: skip: func.func
2318+
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
2319+
(void *)(&data), MlirWalkPreOrder);
2320+
22852321
mlirModuleDestroy(module);
22862322
return 0;
22872323
}

mlir/test/python/ir/operation.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,11 @@ def testOpWalk():
10321032
""",
10331033
ctx,
10341034
)
1035-
callback = lambda op: print(op.name)
1035+
1036+
def callback(op):
1037+
print(op.name)
1038+
return WalkResult.ADVANCE
1039+
10361040
# Test post-order walk (default).
10371041
# CHECK-NEXT: Post-order
10381042
# CHECK-NEXT: func.return
@@ -1047,4 +1051,37 @@ def testOpWalk():
10471051
# CHECK-NEXT: func.fun
10481052
# CHECK-NEXT: func.return
10491053
print("Pre-order")
1050-
module.operation.walk(callback, True)
1054+
module.operation.walk(callback, WalkOrder.PRE_ORDER)
1055+
1056+
# Test interrput.
1057+
# CHECK-NEXT: Interrupt post-order
1058+
# CHECK-NEXT: func.return
1059+
print("Interrupt post-order")
1060+
def callback(op):
1061+
print(op.name)
1062+
return WalkResult.INTERRUPT
1063+
module.operation.walk(callback)
1064+
1065+
# Test skip.
1066+
# CHECK-NEXT: Skip pre-order
1067+
# CHECK-NEXT: builtin.module
1068+
print("Skip pre-order")
1069+
def callback(op):
1070+
print(op.name)
1071+
return WalkResult.SKIP
1072+
module.operation.walk(callback, WalkOrder.PRE_ORDER)
1073+
1074+
# Test exception.
1075+
# CHECK: Exception
1076+
# CHECK-NEXT: func.return
1077+
# CHECK-NEXT: Exception raised
1078+
print("Exception")
1079+
def callback(op):
1080+
print(op.name)
1081+
raise ValueError
1082+
return WalkResult.ADVANCE
1083+
1084+
try:
1085+
module.operation.walk(callback)
1086+
except ValueError:
1087+
print("Exception raised")

0 commit comments

Comments
 (0)