Skip to content

Commit d7e4973

Browse files
authored
[mlir][CAPI, python bindings] Expose Operation::setSuccessor (llvm#67922)
This is useful for emitting (using the python bindings) `cf.br` to blocks that are declared lexically post block creation.
1 parent 811b05c commit d7e4973

File tree

4 files changed

+130
-11
lines changed

4 files changed

+130
-11
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
576576
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
577577
intptr_t pos);
578578

579+
/// Set `pos`-th successor of the operation.
580+
MLIR_CAPI_EXPORTED void
581+
mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block);
582+
579583
/// Returns true if this operation defines an inherent attribute with this name.
580584
/// Note: the attribute can be optional, so
581585
/// `mlirOperationGetInherentAttributeByName` can still return a null attribute.

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2207,9 +2207,9 @@ class PyBlockArgumentList
22072207
};
22082208

22092209
/// A list of operation operands. Internally, these are stored as consecutive
2210-
/// elements, random access is cheap. The result list is associated with the
2211-
/// operation whose results these are, and extends the lifetime of this
2212-
/// operation.
2210+
/// elements, random access is cheap. The (returned) operand list is associated
2211+
/// with the operation whose operands these are, and thus extends the lifetime
2212+
/// of this operation.
22132213
class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
22142214
public:
22152215
static constexpr const char *pyClassName = "OpOperandList";
@@ -2262,9 +2262,9 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
22622262
};
22632263

22642264
/// A list of operation results. Internally, these are stored as consecutive
2265-
/// elements, random access is cheap. The result list is associated with the
2266-
/// operation whose results these are, and extends the lifetime of this
2267-
/// operation.
2265+
/// elements, random access is cheap. The (returned) result list is associated
2266+
/// with the operation whose results these are, and thus extends the lifetime of
2267+
/// this operation.
22682268
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
22692269
public:
22702270
static constexpr const char *pyClassName = "OpResultList";
@@ -2307,6 +2307,52 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
23072307
PyOperationRef operation;
23082308
};
23092309

2310+
/// A list of operation successors. Internally, these are stored as consecutive
2311+
/// elements, random access is cheap. The (returned) successor list is
2312+
/// associated with the operation whose successors these are, and thus extends
2313+
/// the lifetime of this operation.
2314+
class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2315+
public:
2316+
static constexpr const char *pyClassName = "OpSuccessors";
2317+
2318+
PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2319+
intptr_t length = -1, intptr_t step = 1)
2320+
: Sliceable(startIndex,
2321+
length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2322+
: length,
2323+
step),
2324+
operation(operation) {}
2325+
2326+
void dunderSetItem(intptr_t index, PyBlock block) {
2327+
index = wrapIndex(index);
2328+
mlirOperationSetSuccessor(operation->get(), index, block.get());
2329+
}
2330+
2331+
static void bindDerived(ClassTy &c) {
2332+
c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2333+
}
2334+
2335+
private:
2336+
/// Give the parent CRTP class access to hook implementations below.
2337+
friend class Sliceable<PyOpSuccessors, PyBlock>;
2338+
2339+
intptr_t getRawNumElements() {
2340+
operation->checkValid();
2341+
return mlirOperationGetNumSuccessors(operation->get());
2342+
}
2343+
2344+
PyBlock getRawElement(intptr_t pos) {
2345+
MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2346+
return PyBlock(operation, block);
2347+
}
2348+
2349+
PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2350+
return PyOpSuccessors(operation, startIndex, length, step);
2351+
}
2352+
2353+
PyOperationRef operation;
2354+
};
2355+
23102356
/// A list of operation attributes. Can be indexed by name, producing
23112357
/// attributes, or by index, producing named attributes.
23122358
class PyOpAttributeMap {
@@ -2924,16 +2970,28 @@ void mlir::python::populateIRCore(py::module &m) {
29242970
&PyOperation::getCapsule)
29252971
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
29262972
.def_property_readonly("operation", [](py::object self) { return self; })
2927-
.def_property_readonly("opview", &PyOperation::createOpView);
2973+
.def_property_readonly("opview", &PyOperation::createOpView)
2974+
.def_property_readonly(
2975+
"successors",
2976+
[](PyOperationBase &self) {
2977+
return PyOpSuccessors(self.getOperation().getRef());
2978+
},
2979+
"Returns the list of Operation successors.");
29282980

29292981
auto opViewClass =
29302982
py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
29312983
.def(py::init<py::object>(), py::arg("operation"))
29322984
.def_property_readonly("operation", &PyOpView::getOperationObject)
29332985
.def_property_readonly("opview", [](py::object self) { return self; })
2934-
.def("__str__", [](PyOpView &self) {
2935-
return py::str(self.getOperationObject());
2936-
});
2986+
.def(
2987+
"__str__",
2988+
[](PyOpView &self) { return py::str(self.getOperationObject()); })
2989+
.def_property_readonly(
2990+
"successors",
2991+
[](PyOperationBase &self) {
2992+
return PyOpSuccessors(self.getOperation().getRef());
2993+
},
2994+
"Returns the list of Operation successors.");
29372995
opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
29382996
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
29392997
opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
@@ -3448,7 +3506,8 @@ void mlir::python::populateIRCore(py::module &m) {
34483506
mlirOpPrintingFlagsUseLocalScope(flags);
34493507
valueState = mlirAsmStateCreateForValue(self.get(), flags);
34503508
}
3451-
mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(),
3509+
mlirValuePrintAsOperand(self.get(), valueState,
3510+
printAccum.getCallback(),
34523511
printAccum.getUserData());
34533512
// Release state if allocated locally.
34543513
if (!state) {
@@ -3523,6 +3582,7 @@ void mlir::python::populateIRCore(py::module &m) {
35233582
PyOpOperandIterator::bind(m);
35243583
PyOpOperandList::bind(m);
35253584
PyOpResultList::bind(m);
3585+
PyOpSuccessors::bind(m);
35263586
PyRegionIterator::bind(m);
35273587
PyRegionList::bind(m);
35283588

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,11 @@ bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
637637
return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
638638
}
639639

640+
void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos,
641+
MlirBlock block) {
642+
unwrap(op)->setSuccessor(unwrap(block), static_cast<unsigned>(pos));
643+
}
644+
640645
intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
641646
return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
642647
}

mlir/test/python/dialects/cf.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import cf
5+
6+
7+
def constructAndPrintInModule(f):
8+
print("\nTEST:", f.__name__)
9+
with Context() as ctx, Location.unknown():
10+
ctx.allow_unregistered_dialects = True
11+
module = Module.create()
12+
with InsertionPoint(module.body):
13+
f()
14+
return f
15+
16+
17+
# CHECK-LABEL: TEST: testBranchAndSetSuccessor
18+
@constructAndPrintInModule
19+
def testBranchAndSetSuccessor():
20+
op1 = Operation.create("custom.op1", regions=1)
21+
22+
block0 = op1.regions[0].blocks.append()
23+
ip = InsertionPoint(block0)
24+
Operation.create("custom.terminator", ip=ip)
25+
26+
block1 = op1.regions[0].blocks.append()
27+
ip = InsertionPoint(block1)
28+
br1 = cf.BranchOp([], block1, ip=ip)
29+
# CHECK: ^bb1: // pred: ^bb1
30+
# CHECK: cf.br ^bb1
31+
print(br1.successors[0])
32+
# CHECK: num_successors 1
33+
print("num_successors", len(br1.successors))
34+
35+
block2 = op1.regions[0].blocks.append()
36+
ip = InsertionPoint(block2)
37+
br2 = cf.BranchOp([], block1, ip=ip)
38+
# CHECK: ^bb1: // 2 preds: ^bb1, ^bb2
39+
# CHECK: cf.br ^bb1
40+
print(br2.successors[0])
41+
# CHECK: num_successors 1
42+
print("num_successors", len(br2.successors))
43+
44+
br1.successors[0] = block2
45+
# CHECK: ^bb2: // pred: ^bb1
46+
# CHECK: cf.br ^bb1
47+
print(br1.successors[0])
48+
# CHECK: ^bb1: // pred: ^bb2
49+
# CHECK: cf.br ^bb2
50+
print(br2.operation.successors[0])

0 commit comments

Comments
 (0)