Skip to content

Commit fd2f3c9

Browse files
maksleventalAnthony Tran
authored andcommitted
[mlir][python] bind block predecessors and successors (llvm#145116)
bind `block.getSuccessor` and `block.getPredecessors`.
1 parent 49392d3 commit fd2f3c9

File tree

5 files changed

+223
-4
lines changed

5 files changed

+223
-4
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,24 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
986986
MLIR_CAPI_EXPORTED void
987987
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
988988

989+
/// Returns the number of successor blocks of the block.
990+
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
991+
992+
/// Returns `pos`-th successor of the block.
993+
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
994+
intptr_t pos);
995+
996+
/// Returns the number of predecessor blocks of the block.
997+
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);
998+
999+
/// Returns `pos`-th predecessor of the block.
1000+
///
1001+
/// WARNING: This getter is more expensive than the others here because
1002+
/// the impl actually iterates the use-def chain (of block operands) anew for
1003+
/// each indexed access.
1004+
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
1005+
intptr_t pos);
1006+
9891007
//===----------------------------------------------------------------------===//
9901008
// Value API.
9911009
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2626,6 +2626,88 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
26262626
PyOperationRef operation;
26272627
};
26282628

2629+
/// A list of block successors. Internally, these are stored as consecutive
2630+
/// elements, random access is cheap. The (returned) successor list is
2631+
/// associated with the operation and block whose successors these are, and thus
2632+
/// extends the lifetime of this operation and block.
2633+
class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
2634+
public:
2635+
static constexpr const char *pyClassName = "BlockSuccessors";
2636+
2637+
PyBlockSuccessors(PyBlock block, PyOperationRef operation,
2638+
intptr_t startIndex = 0, intptr_t length = -1,
2639+
intptr_t step = 1)
2640+
: Sliceable(startIndex,
2641+
length == -1 ? mlirBlockGetNumSuccessors(block.get())
2642+
: length,
2643+
step),
2644+
operation(operation), block(block) {}
2645+
2646+
private:
2647+
/// Give the parent CRTP class access to hook implementations below.
2648+
friend class Sliceable<PyBlockSuccessors, PyBlock>;
2649+
2650+
intptr_t getRawNumElements() {
2651+
block.checkValid();
2652+
return mlirBlockGetNumSuccessors(block.get());
2653+
}
2654+
2655+
PyBlock getRawElement(intptr_t pos) {
2656+
MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2657+
return PyBlock(operation, block);
2658+
}
2659+
2660+
PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2661+
return PyBlockSuccessors(block, operation, startIndex, length, step);
2662+
}
2663+
2664+
PyOperationRef operation;
2665+
PyBlock block;
2666+
};
2667+
2668+
/// A list of block predecessors. The (returned) predecessor list is
2669+
/// associated with the operation and block whose predecessors these are, and
2670+
/// thus extends the lifetime of this operation and block.
2671+
///
2672+
/// WARNING: This Sliceable is more expensive than the others here because
2673+
/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
2674+
/// operands) anew for each indexed access.
2675+
class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
2676+
public:
2677+
static constexpr const char *pyClassName = "BlockPredecessors";
2678+
2679+
PyBlockPredecessors(PyBlock block, PyOperationRef operation,
2680+
intptr_t startIndex = 0, intptr_t length = -1,
2681+
intptr_t step = 1)
2682+
: Sliceable(startIndex,
2683+
length == -1 ? mlirBlockGetNumPredecessors(block.get())
2684+
: length,
2685+
step),
2686+
operation(operation), block(block) {}
2687+
2688+
private:
2689+
/// Give the parent CRTP class access to hook implementations below.
2690+
friend class Sliceable<PyBlockPredecessors, PyBlock>;
2691+
2692+
intptr_t getRawNumElements() {
2693+
block.checkValid();
2694+
return mlirBlockGetNumPredecessors(block.get());
2695+
}
2696+
2697+
PyBlock getRawElement(intptr_t pos) {
2698+
MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2699+
return PyBlock(operation, block);
2700+
}
2701+
2702+
PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
2703+
intptr_t step) {
2704+
return PyBlockPredecessors(block, operation, startIndex, length, step);
2705+
}
2706+
2707+
PyOperationRef operation;
2708+
PyBlock block;
2709+
};
2710+
26292711
/// A list of operation attributes. Can be indexed by name, producing
26302712
/// attributes, or by index, producing named attributes.
26312713
class PyOpAttributeMap {
@@ -3655,7 +3737,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36553737
},
36563738
nb::arg("operation"),
36573739
"Appends an operation to this block. If the operation is currently "
3658-
"in another block, it will be moved.");
3740+
"in another block, it will be moved.")
3741+
.def_prop_ro(
3742+
"successors",
3743+
[](PyBlock &self) {
3744+
return PyBlockSuccessors(self, self.getParentOperation());
3745+
},
3746+
"Returns the list of Block successors.")
3747+
.def_prop_ro(
3748+
"predecessors",
3749+
[](PyBlock &self) {
3750+
return PyBlockPredecessors(self, self.getParentOperation());
3751+
},
3752+
"Returns the list of Block predecessors.");
36593753

36603754
//----------------------------------------------------------------------------
36613755
// Mapping of PyInsertionPoint.
@@ -4099,6 +4193,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
40994193
PyBlockArgumentList::bind(m);
41004194
PyBlockIterator::bind(m);
41014195
PyBlockList::bind(m);
4196+
PyBlockSuccessors::bind(m);
4197+
PyBlockPredecessors::bind(m);
41024198
PyOperationIterator::bind(m);
41034199
PyOperationList::bind(m);
41044200
PyOpAttributeMap::bind(m);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
10591059
unwrap(block)->print(stream);
10601060
}
10611061

1062+
intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
1063+
return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
1064+
}
1065+
1066+
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
1067+
return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
1068+
}
1069+
1070+
intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
1071+
Block *b = unwrap(block);
1072+
return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
1073+
}
1074+
1075+
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
1076+
Block *b = unwrap(block);
1077+
Block::pred_iterator it = b->pred_begin();
1078+
std::advance(it, pos);
1079+
return wrap(*it);
1080+
}
1081+
10621082
//===----------------------------------------------------------------------===//
10631083
// Value API.
10641084
//===----------------------------------------------------------------------===//

mlir/test/CAPI/ir.c

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2440,6 +2440,74 @@ void testDiagnostics(void) {
24402440
mlirContextDestroy(ctx);
24412441
}
24422442

2443+
int testBlockPredecessorsSuccessors(MlirContext ctx) {
2444+
// CHECK-LABEL: @testBlockPredecessorsSuccessors
2445+
fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
2446+
2447+
const char *moduleString = "module {\n"
2448+
" func.func @test(%arg0: i32, %arg1: i16) {\n"
2449+
" cf.br ^bb1(%arg1 : i16)\n"
2450+
" ^bb1(%0: i16): // pred: ^bb0\n"
2451+
" cf.br ^bb2(%arg0 : i32)\n"
2452+
" ^bb2(%1: i32): // pred: ^bb1\n"
2453+
" return\n"
2454+
" }\n"
2455+
"}\n";
2456+
2457+
MlirModule module =
2458+
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
2459+
2460+
MlirOperation moduleOp = mlirModuleGetOperation(module);
2461+
MlirRegion moduleRegion = mlirOperationGetRegion(moduleOp, 0);
2462+
MlirBlock moduleBlock = mlirRegionGetFirstBlock(moduleRegion);
2463+
MlirOperation function = mlirBlockGetFirstOperation(moduleBlock);
2464+
MlirRegion funcRegion = mlirOperationGetRegion(function, 0);
2465+
MlirBlock entryBlock = mlirRegionGetFirstBlock(funcRegion);
2466+
MlirBlock middleBlock = mlirBlockGetNextInRegion(entryBlock);
2467+
MlirBlock successorBlock = mlirBlockGetNextInRegion(middleBlock);
2468+
2469+
#define FPRINTF_OP(OP, FMT) fprintf(stderr, #OP ": " FMT "\n", OP)
2470+
2471+
// CHECK: mlirBlockGetNumPredecessors(entryBlock): 0
2472+
FPRINTF_OP(mlirBlockGetNumPredecessors(entryBlock), "%ld");
2473+
2474+
// CHECK: mlirBlockGetNumSuccessors(entryBlock): 1
2475+
FPRINTF_OP(mlirBlockGetNumSuccessors(entryBlock), "%ld");
2476+
// CHECK: mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)): 1
2477+
FPRINTF_OP(mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)),
2478+
"%d");
2479+
// CHECK: mlirBlockGetNumPredecessors(middleBlock): 1
2480+
FPRINTF_OP(mlirBlockGetNumPredecessors(middleBlock), "%ld");
2481+
// CHECK: mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0))
2482+
FPRINTF_OP(
2483+
mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0)),
2484+
"%d");
2485+
2486+
// CHECK: mlirBlockGetNumSuccessors(middleBlock): 1
2487+
FPRINTF_OP(mlirBlockGetNumSuccessors(middleBlock), "%ld");
2488+
// CHECK: BlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)): 1
2489+
fprintf(
2490+
stderr,
2491+
"BlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)): %d\n",
2492+
mlirBlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)));
2493+
// CHECK: mlirBlockGetNumPredecessors(successorBlock): 1
2494+
FPRINTF_OP(mlirBlockGetNumPredecessors(successorBlock), "%ld");
2495+
// CHECK: Equal(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)): 1
2496+
fprintf(
2497+
stderr,
2498+
"Equal(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)): %d\n",
2499+
mlirBlockEqual(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)));
2500+
2501+
// CHECK: mlirBlockGetNumSuccessors(successorBlock): 0
2502+
FPRINTF_OP(mlirBlockGetNumSuccessors(successorBlock), "%ld");
2503+
2504+
#undef FPRINTF_OP
2505+
2506+
mlirModuleDestroy(module);
2507+
2508+
return 0;
2509+
}
2510+
24432511
int main(void) {
24442512
MlirContext ctx = mlirContextCreate();
24452513
registerAllUpstreamDialects(ctx);
@@ -2486,6 +2554,9 @@ int main(void) {
24862554
testExplicitThreadPools();
24872555
testDiagnostics();
24882556

2557+
if (testBlockPredecessorsSuccessors(ctx))
2558+
return 17;
2559+
24892560
// CHECK: DESTROY MAIN CONTEXT
24902561
// CHECK: reportResourceDelete: resource_i64_blob
24912562
fprintf(stderr, "DESTROY MAIN CONTEXT\n");

mlir/test/python/ir/blocks.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4-
import io
5-
import itertools
6-
from mlir.ir import *
4+
75
from mlir.dialects import builtin
86
from mlir.dialects import cf
97
from mlir.dialects import func
8+
from mlir.ir import *
109

1110

1211
def run(f):
@@ -54,10 +53,25 @@ def testBlockCreation():
5453
with InsertionPoint(middle_block) as middle_ip:
5554
assert middle_ip.block == middle_block
5655
cf.BranchOp([i32_arg], dest=successor_block)
56+
5757
module.print(enable_debug_info=True)
5858
# Ensure region back references are coherent.
5959
assert entry_block.region == middle_block.region == successor_block.region
6060

61+
assert len(entry_block.predecessors) == 0
62+
63+
assert len(entry_block.successors) == 1
64+
assert middle_block == entry_block.successors[0]
65+
assert len(middle_block.predecessors) == 1
66+
assert entry_block == middle_block.predecessors[0]
67+
68+
assert len(middle_block.successors) == 1
69+
assert successor_block == middle_block.successors[0]
70+
assert len(successor_block.predecessors) == 1
71+
assert middle_block == successor_block.predecessors[0]
72+
73+
assert len(successor_block.successors) == 0
74+
6175

6276
# CHECK-LABEL: TEST: testBlockCreationArgLocs
6377
@run

0 commit comments

Comments
 (0)