Skip to content

[mlir][python] Add walk method to PyOperationBase #87962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
MlirOperation other);

/// Operation walk result.
typedef enum MlirWalkResult {
MlirWalkResultAdvance,
MlirWalkResultInterrupt,
MlirWalkResultSkip
} MlirWalkResult;

/// Traversal order for operation walk.
typedef enum MlirWalkOrder {
MlirWalkPreOrder,
Expand All @@ -713,7 +720,8 @@ typedef enum MlirWalkOrder {

/// Operation walker type. The handler is passed an (opaque) reference to an
/// operation and a pointer to a `userData`.
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation,
void *userData);

/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
/// `*userData` is passed to the callback as well and can be used to tunnel some
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H

#include <pybind11/functional.h>
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed to implicitly convert python object to std::function.

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
Expand Down
32 changes: 29 additions & 3 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
data->rootOp.getOperation().getContext()->clearOperation(op);
else
data->rootSeen = true;
return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
Expand Down Expand Up @@ -1249,6 +1250,21 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
.str());
}

void PyOperationBase::walk(
std::function<MlirWalkResult(MlirOperation)> callback,
MlirWalkOrder walkOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
auto *fn =
static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
return (*fn)(op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: can you actually check the result here to see if it's py::none() (i.e., the callback doesn't have a return statement) and then default to MlirWalkResult::MlirWalkResultAdvance. I think that's reasonable semantics? @ftynse

Copy link
Member Author

@uenoku uenoku Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure it works since the function is type casted to std::function<MlirWalkResult(MlirOperation)>.
To allow void functions I think we could change the callback type to std::variant<std::function<MlirWalkResult(MlirOperation)>, std::function<void(MlirOperation)>> or maybe py::function (and dynamically type casts the returned type).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variant is overkill for sure but py::function might work? anyway it's just a thought rather than a hard request so nbd if you don't want to bother.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My two cents is if we can do this in a way that doesn't add much overhead it would be nice. But I'm also happy to see this go in with the user callback required to explicitly return a WalkResult... "explicit is better than implicit".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm too traumatized by the C++ version of omitting the return from a function and getting a random segfault, so I'm on the side of always having an explicit return. But not objecting strongly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm too traumatized by the C++ version of omitting the return from a function and getting a random segfault

I am similarly traumatized but in Python-land I think it's fairly well-known that a function with no return statement actually returns None.

};

mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like the c++ walk to evolve to support two callbacks to have both pre/post order visitation in the same walk.

This isn't a blocker for your change, but I'll need to break the C API, so just a heads up :)
(I suspect we can preserve the Python API backward compatible on top of it)

py::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
Expand Down Expand Up @@ -2511,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) {
.value("NOTE", MlirDiagnosticNote)
.value("REMARK", MlirDiagnosticRemark);

py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
.value("PRE_ORDER", MlirWalkPreOrder)
.value("POST_ORDER", MlirWalkPostOrder);

py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
.value("ADVANCE", MlirWalkResultAdvance)
.value("INTERRUPT", MlirWalkResultInterrupt)
.value("SKIP", MlirWalkResultSkip);

//----------------------------------------------------------------------------
// Mapping of Diagnostics.
//----------------------------------------------------------------------------
Expand Down Expand Up @@ -2989,8 +3014,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
bool, py::object, bool>(
&PyOperationBase::print),
bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
Expand Down Expand Up @@ -3038,7 +3062,9 @@ void mlir::python::populateIRCore(py::module &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
.def("walk", &PyOperationBase::walk, py::arg("callback"),
py::arg("walk_order") = MlirWalkPostOrder);

py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,10 @@ class PyOperationBase {
void writeBytecode(const pybind11::object &fileObject,
std::optional<int64_t> bytecodeVersion);

// Implement the walk method.
void walk(std::function<MlirWalkResult(MlirOperation)> callback,
MlirWalkOrder walkOrder);

/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
Expand Down
21 changes: 19 additions & 2 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
return unwrap(op)->moveBefore(unwrap(other));
}

static mlir::WalkResult unwrap(MlirWalkResult result) {
switch (result) {
case MlirWalkResultAdvance:
return mlir::WalkResult::advance();

case MlirWalkResultInterrupt:
return mlir::WalkResult::interrupt();

case MlirWalkResultSkip:
return mlir::WalkResult::skip();
}
}

void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder) {
switch (walkOrder) {

case MlirWalkPreOrder:
unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
[callback, userData](Operation *op) { callback(wrap(op), userData); });
[callback, userData](Operation *op) {
return unwrap(callback(wrap(op), userData));
});
break;
case MlirWalkPostOrder:
unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
[callback, userData](Operation *op) { callback(wrap(op), userData); });
[callback, userData](Operation *op) {
return unwrap(callback(wrap(op), userData));
});
}
}

Expand Down
58 changes: 47 additions & 11 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -2244,9 +2244,22 @@ typedef struct {
const char *x;
} callBackData;

void walkCallBack(MlirOperation op, void *rootOpVoid) {
MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
mlirIdentifierStr(mlirOperationGetName(op)).data);
return MlirWalkResultAdvance;
}

MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
mlirIdentifierStr(mlirOperationGetName(op)).data);
if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "func.func") ==
0)
return MlirWalkResultSkip;
if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "arith.addi") ==
0)
return MlirWalkResultInterrupt;
return MlirWalkResultAdvance;
}

int testOperationWalk(MlirContext ctx) {
Expand All @@ -2259,29 +2272,52 @@ int testOperationWalk(MlirContext ctx) {
" arith.addi %1, %1: i32\n"
" return\n"
" }\n"
" func.func @bar() {\n"
" return\n"
" }\n"
"}";
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));

callBackData data;
data.x = "i love you";

// CHECK: i love you: arith.constant
// CHECK: i love you: arith.addi
// CHECK: i love you: func.return
// CHECK: i love you: func.func
// CHECK: i love you: builtin.module
// CHECK-NEXT: i love you: arith.constant
// CHECK-NEXT: i love you: arith.addi
// CHECK-NEXT: i love you: func.return
// CHECK-NEXT: i love you: func.func
// CHECK-NEXT: i love you: func.return
// CHECK-NEXT: i love you: func.func
// CHECK-NEXT: i love you: builtin.module
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPostOrder);

data.x = "i don't love you";
// CHECK: i don't love you: builtin.module
// CHECK: i don't love you: func.func
// CHECK: i don't love you: arith.constant
// CHECK: i don't love you: arith.addi
// CHECK: i don't love you: func.return
// CHECK-NEXT: i don't love you: builtin.module
// CHECK-NEXT: i don't love you: func.func
// CHECK-NEXT: i don't love you: arith.constant
// CHECK-NEXT: i don't love you: arith.addi
// CHECK-NEXT: i don't love you: func.return
// CHECK-NEXT: i don't love you: func.func
// CHECK-NEXT: i don't love you: func.return
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPreOrder);

data.x = "interrupt";
// Interrupted at `arith.addi`
// CHECK-NEXT: interrupt: arith.constant
// CHECK-NEXT: interrupt: arith.addi
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
(void *)(&data), MlirWalkPostOrder);

data.x = "skip";
// Skip at `func.func`
// CHECK-NEXT: skip: builtin.module
// CHECK-NEXT: skip: func.func
// CHECK-NEXT: skip: func.func
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
(void *)(&data), MlirWalkPreOrder);

mlirModuleDestroy(module);
return 0;
}
Expand Down
75 changes: 75 additions & 0 deletions mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,3 +1015,78 @@ def testOperationParse():
print(
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
)


# CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
ctx = Context()
ctx.allow_unregistered_dialects = True
module = Module.parse(
r"""
builtin.module {
func.func @f() {
func.return
}
}
""",
ctx,
)

def callback(op):
print(op.name)
return WalkResult.ADVANCE

# Test post-order walk (default).
# CHECK-NEXT: Post-order
# CHECK-NEXT: func.return
# CHECK-NEXT: func.func
# CHECK-NEXT: builtin.module
print("Post-order")
module.operation.walk(callback)

# Test pre-order walk.
# CHECK-NEXT: Pre-order
# CHECK-NEXT: builtin.module
# CHECK-NEXT: func.fun
# CHECK-NEXT: func.return
print("Pre-order")
module.operation.walk(callback, WalkOrder.PRE_ORDER)

# Test interrput.
# CHECK-NEXT: Interrupt post-order
# CHECK-NEXT: func.return
print("Interrupt post-order")

def callback(op):
print(op.name)
return WalkResult.INTERRUPT

module.operation.walk(callback)

# Test skip.
# CHECK-NEXT: Skip pre-order
# CHECK-NEXT: builtin.module
print("Skip pre-order")

def callback(op):
print(op.name)
return WalkResult.SKIP

module.operation.walk(callback, WalkOrder.PRE_ORDER)

# Test exception.
# CHECK: Exception
# CHECK-NEXT: func.return
# CHECK-NEXT: Exception raised
print("Exception")

def callback(op):
print(op.name)
raise ValueError
return WalkResult.ADVANCE

try:
module.operation.walk(callback)
except ValueError:
print("Exception raised")