Skip to content

[mlir][CAPI, python bindings] Expose Operation::setSuccessor #67922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
intptr_t pos);

/// Set `pos`-th successor of the operation.
MLIR_CAPI_EXPORTED void
mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block);

/// Returns true if this operation defines an inherent attribute with this name.
/// Note: the attribute can be optional, so
/// `mlirOperationGetInherentAttributeByName` can still return a null attribute.
Expand Down
82 changes: 71 additions & 11 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2207,9 +2207,9 @@ class PyBlockArgumentList
};

/// A list of operation operands. Internally, these are stored as consecutive
/// elements, random access is cheap. The result list is associated with the
/// operation whose results these are, and extends the lifetime of this
/// operation.
/// elements, random access is cheap. The (returned) operand list is associated
/// with the operation whose operands these are, and thus extends the lifetime
/// of this operation.
class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
public:
static constexpr const char *pyClassName = "OpOperandList";
Expand Down Expand Up @@ -2262,9 +2262,9 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
};

/// A list of operation results. Internally, these are stored as consecutive
/// elements, random access is cheap. The result list is associated with the
/// operation whose results these are, and extends the lifetime of this
/// operation.
/// elements, random access is cheap. The (returned) result list is associated
/// with the operation whose results these are, and thus extends the lifetime of
/// this operation.
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
public:
static constexpr const char *pyClassName = "OpResultList";
Expand Down Expand Up @@ -2307,6 +2307,52 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
PyOperationRef operation;
};

/// A list of operation successors. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation whose successors these are, and thus extends
/// the lifetime of this operation.
class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
public:
static constexpr const char *pyClassName = "OpSuccessors";

PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirOperationGetNumSuccessors(operation->get())
: length,
step),
operation(operation) {}

void dunderSetItem(intptr_t index, PyBlock block) {
index = wrapIndex(index);
mlirOperationSetSuccessor(operation->get(), index, block.get());
}

static void bindDerived(ClassTy &c) {
c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
}

private:
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyOpSuccessors, PyBlock>;

intptr_t getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumSuccessors(operation->get());
}

PyBlock getRawElement(intptr_t pos) {
MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
return PyBlock(operation, block);
}

PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
return PyOpSuccessors(operation, startIndex, length, step);
}

PyOperationRef operation;
};

/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
class PyOpAttributeMap {
Expand Down Expand Up @@ -2924,16 +2970,28 @@ void mlir::python::populateIRCore(py::module &m) {
&PyOperation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
.def_property_readonly("operation", [](py::object self) { return self; })
.def_property_readonly("opview", &PyOperation::createOpView);
.def_property_readonly("opview", &PyOperation::createOpView)
.def_property_readonly(
"successors",
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
"Returns the list of Operation successors.");

auto opViewClass =
py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
.def(py::init<py::object>(), py::arg("operation"))
.def_property_readonly("operation", &PyOpView::getOperationObject)
.def_property_readonly("opview", [](py::object self) { return self; })
.def("__str__", [](PyOpView &self) {
return py::str(self.getOperationObject());
});
.def(
"__str__",
[](PyOpView &self) { return py::str(self.getOperationObject()); })
.def_property_readonly(
"successors",
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
"Returns the list of Operation successors.");
opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
Expand Down Expand Up @@ -3448,7 +3506,8 @@ void mlir::python::populateIRCore(py::module &m) {
mlirOpPrintingFlagsUseLocalScope(flags);
valueState = mlirAsmStateCreateForValue(self.get(), flags);
}
mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(),
mlirValuePrintAsOperand(self.get(), valueState,
printAccum.getCallback(),
printAccum.getUserData());
// Release state if allocated locally.
if (!state) {
Expand Down Expand Up @@ -3523,6 +3582,7 @@ void mlir::python::populateIRCore(py::module &m) {
PyOpOperandIterator::bind(m);
PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyOpSuccessors::bind(m);
PyRegionIterator::bind(m);
PyRegionList::bind(m);

Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,11 @@ bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
}

void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos,
MlirBlock block) {
unwrap(op)->setSuccessor(unwrap(block), static_cast<unsigned>(pos));
}

intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
}
Expand Down
50 changes: 50 additions & 0 deletions mlir/test/python/dialects/cf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import cf


def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
module = Module.create()
with InsertionPoint(module.body):
f()
return f


# CHECK-LABEL: TEST: testBranchAndSetSuccessor
@constructAndPrintInModule
def testBranchAndSetSuccessor():
op1 = Operation.create("custom.op1", regions=1)

block0 = op1.regions[0].blocks.append()
ip = InsertionPoint(block0)
Operation.create("custom.terminator", ip=ip)

block1 = op1.regions[0].blocks.append()
ip = InsertionPoint(block1)
br1 = cf.BranchOp([], block1, ip=ip)
# CHECK: ^bb1: // pred: ^bb1
# CHECK: cf.br ^bb1
print(br1.successors[0])
# CHECK: num_successors 1
print("num_successors", len(br1.successors))

block2 = op1.regions[0].blocks.append()
ip = InsertionPoint(block2)
br2 = cf.BranchOp([], block1, ip=ip)
# CHECK: ^bb1: // 2 preds: ^bb1, ^bb2
# CHECK: cf.br ^bb1
print(br2.successors[0])
# CHECK: num_successors 1
print("num_successors", len(br2.successors))

br1.successors[0] = block2
# CHECK: ^bb2: // pred: ^bb1
# CHECK: cf.br ^bb1
print(br1.successors[0])
# CHECK: ^bb1: // pred: ^bb2
# CHECK: cf.br ^bb2
print(br2.operation.successors[0])