Skip to content

[mlir][python] bind block predecessors and successors #145116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,24 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
MLIR_CAPI_EXPORTED void
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);

/// Returns the number of successor blocks of the block.
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have tests for the C API as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a test


/// Returns `pos`-th successor of the block.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
intptr_t pos);

/// Returns the number of predecessor blocks of the block.
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);

/// Returns `pos`-th predecessor of the block.
///
/// WARNING: This getter is more expensive than the others here because
/// the impl actually iterates the use-def chain (of block operands) anew for
/// each indexed access.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
intptr_t pos);

//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
Expand Down
98 changes: 97 additions & 1 deletion mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2626,6 +2626,88 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
PyOperationRef operation;
};

/// A list of block successors. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation and block whose successors these are, and thus
/// extends the lifetime of this operation and block.
class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockSuccessors";

PyBlockSuccessors(PyBlock block, PyOperationRef operation,
intptr_t startIndex = 0, intptr_t length = -1,
intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirBlockGetNumSuccessors(block.get())
: length,
step),
operation(operation), block(block) {}

private:
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockSuccessors, PyBlock>;

intptr_t getRawNumElements() {
block.checkValid();
return mlirBlockGetNumSuccessors(block.get());
}

PyBlock getRawElement(intptr_t pos) {
MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
return PyBlock(operation, block);
}

PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
return PyBlockSuccessors(block, operation, startIndex, length, step);
}

PyOperationRef operation;
PyBlock block;
};

/// A list of block predecessors. The (returned) predecessor list is
/// associated with the operation and block whose predecessors these are, and
/// thus extends the lifetime of this operation and block.
///
/// WARNING: This Sliceable is more expensive than the others here because
/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
/// operands) anew for each indexed access.
class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockPredecessors";

PyBlockPredecessors(PyBlock block, PyOperationRef operation,
intptr_t startIndex = 0, intptr_t length = -1,
intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirBlockGetNumPredecessors(block.get())
: length,
step),
operation(operation), block(block) {}

private:
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockPredecessors, PyBlock>;

intptr_t getRawNumElements() {
block.checkValid();
return mlirBlockGetNumPredecessors(block.get());
}

PyBlock getRawElement(intptr_t pos) {
MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
return PyBlock(operation, block);
}

PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
intptr_t step) {
return PyBlockPredecessors(block, operation, startIndex, length, step);
}

PyOperationRef operation;
PyBlock block;
};

/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
class PyOpAttributeMap {
Expand Down Expand Up @@ -3655,7 +3737,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("operation"),
"Appends an operation to this block. If the operation is currently "
"in another block, it will be moved.");
"in another block, it will be moved.")
.def_prop_ro(
"successors",
[](PyBlock &self) {
return PyBlockSuccessors(self, self.getParentOperation());
},
"Returns the list of Block successors.")
.def_prop_ro(
"predecessors",
[](PyBlock &self) {
return PyBlockPredecessors(self, self.getParentOperation());
},
"Returns the list of Block predecessors.");

//----------------------------------------------------------------------------
// Mapping of PyInsertionPoint.
Expand Down Expand Up @@ -4099,6 +4193,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
PyBlockSuccessors::bind(m);
PyBlockPredecessors::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
unwrap(block)->print(stream);
}

intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
}

MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
}

intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
Block *b = unwrap(block);
return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
}

MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
Block *b = unwrap(block);
Block::pred_iterator it = b->pred_begin();
std::advance(it, pos);
return wrap(*it);
Comment on lines +1077 to +1079
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather avoid iterating over the use-def list every time... This goes through block's use-def chain, maybe there is a way to expose a BlockOperand (and incidentally OpOperand if it isn't) and a getNextUse.

Copy link
Contributor Author

@makslevental makslevental Jun 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think so

/// Implement a predecessor iterator for blocks. This works by walking the use

Compare with SuccessorRange just below there. But maybe I'm wrong and it's just not clicking for me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry just to add a litlle more "context"; if you look at getSuccessors and getSuccessor(unsigned) they work differently but still in a way that doesn't seem possible for getPredecessors:

SuccessorRange getSuccessors() { return SuccessorRange(this); }

...

SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() {
  if (block->empty() || llvm::hasSingleElement(*block->getParent()))
    return;
  Operation *term = &block->back();
  if ((count = term->getNumSuccessors()))
    base = term->getBlockOperands().data();
}

and

Block *Block::getSuccessor(unsigned i) {
  assert(i < getNumSuccessors());
  return getTerminator()->getSuccessor(i);
}

...

class Operation {
  ...
  Block *getSuccessor(unsigned index) {
    assert(index < getNumSuccessors());
    return getBlockOperands()[index].get();
  }
  ...
}

compared with

using pred_iterator = PredecessorIterator;
pred_iterator pred_begin() {
  return pred_iterator((BlockOperand *)getFirstUse());
}
pred_iterator pred_end() { return pred_iterator(nullptr); }
iterator_range<pred_iterator> getPredecessors() {
  return {pred_begin(), pred_end()};
}

so while I agree that iterating the chain isn't great I don't see what else can be done (other than caching those predecessors, which I'm sure we don't want to do either).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define predecessor as returning an iterable object that only has __next__?

Not a big problem if we don't, we use indexed accessors for the linked list of blocks as well because I thought it whoever was using Python didn't care about that level of performance tweaking.

Copy link
Contributor Author

@makslevental makslevental Jun 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is there's no way to implement GetNextPredecessor like GetNextBlockInRegion

https://github.com/llvm/llvm-project/blob/main/mlir/lib/CAPI/IR/IR.cpp#L969

without holding an instance of PredecessorIterator (even forgetting that it needs to be mapped into C).

Anyway ya I'm gonna leave this as is but I'll add a comment mentioning that it's expensive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added warning

}

//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -2440,6 +2440,74 @@ void testDiagnostics(void) {
mlirContextDestroy(ctx);
}

int testBlockPredecessorsSuccessors(MlirContext ctx) {
// CHECK-LABEL: @testBlockPredecessorsSuccessors
fprintf(stderr, "@testBlockPredecessorsSuccessors\n");

const char *moduleString = "module {\n"
" func.func @test(%arg0: i32, %arg1: i16) {\n"
" cf.br ^bb1(%arg1 : i16)\n"
" ^bb1(%0: i16): // pred: ^bb0\n"
" cf.br ^bb2(%arg0 : i32)\n"
" ^bb2(%1: i32): // pred: ^bb1\n"
" return\n"
" }\n"
"}\n";

MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));

MlirOperation moduleOp = mlirModuleGetOperation(module);
MlirRegion moduleRegion = mlirOperationGetRegion(moduleOp, 0);
MlirBlock moduleBlock = mlirRegionGetFirstBlock(moduleRegion);
MlirOperation function = mlirBlockGetFirstOperation(moduleBlock);
MlirRegion funcRegion = mlirOperationGetRegion(function, 0);
MlirBlock entryBlock = mlirRegionGetFirstBlock(funcRegion);
MlirBlock middleBlock = mlirBlockGetNextInRegion(entryBlock);
MlirBlock successorBlock = mlirBlockGetNextInRegion(middleBlock);

#define FPRINTF_OP(OP, FMT) fprintf(stderr, #OP ": " FMT "\n", OP)

// CHECK: mlirBlockGetNumPredecessors(entryBlock): 0
FPRINTF_OP(mlirBlockGetNumPredecessors(entryBlock), "%ld");

// CHECK: mlirBlockGetNumSuccessors(entryBlock): 1
FPRINTF_OP(mlirBlockGetNumSuccessors(entryBlock), "%ld");
// CHECK: mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)): 1
FPRINTF_OP(mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)),
"%d");
// CHECK: mlirBlockGetNumPredecessors(middleBlock): 1
FPRINTF_OP(mlirBlockGetNumPredecessors(middleBlock), "%ld");
// CHECK: mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0))
FPRINTF_OP(
mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0)),
"%d");

// CHECK: mlirBlockGetNumSuccessors(middleBlock): 1
FPRINTF_OP(mlirBlockGetNumSuccessors(middleBlock), "%ld");
// CHECK: BlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)): 1
fprintf(
stderr,
"BlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)): %d\n",
mlirBlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)));
// CHECK: mlirBlockGetNumPredecessors(successorBlock): 1
FPRINTF_OP(mlirBlockGetNumPredecessors(successorBlock), "%ld");
// CHECK: Equal(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)): 1
fprintf(
stderr,
"Equal(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)): %d\n",
mlirBlockEqual(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)));

// CHECK: mlirBlockGetNumSuccessors(successorBlock): 0
FPRINTF_OP(mlirBlockGetNumSuccessors(successorBlock), "%ld");

#undef FPRINTF_OP

mlirModuleDestroy(module);

return 0;
}

int main(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
Expand Down Expand Up @@ -2486,6 +2554,9 @@ int main(void) {
testExplicitThreadPools();
testDiagnostics();

if (testBlockPredecessorsSuccessors(ctx))
return 17;

// CHECK: DESTROY MAIN CONTEXT
// CHECK: reportResourceDelete: resource_i64_blob
fprintf(stderr, "DESTROY MAIN CONTEXT\n");
Expand Down
20 changes: 17 additions & 3 deletions mlir/test/python/ir/blocks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *

from mlir.dialects import builtin
from mlir.dialects import cf
from mlir.dialects import func
from mlir.ir import *


def run(f):
Expand Down Expand Up @@ -54,10 +53,25 @@ def testBlockCreation():
with InsertionPoint(middle_block) as middle_ip:
assert middle_ip.block == middle_block
cf.BranchOp([i32_arg], dest=successor_block)

module.print(enable_debug_info=True)
# Ensure region back references are coherent.
assert entry_block.region == middle_block.region == successor_block.region

assert len(entry_block.predecessors) == 0

assert len(entry_block.successors) == 1
assert middle_block == entry_block.successors[0]
assert len(middle_block.predecessors) == 1
assert entry_block == middle_block.predecessors[0]

assert len(middle_block.successors) == 1
assert successor_block == middle_block.successors[0]
assert len(successor_block.predecessors) == 1
assert middle_block == successor_block.predecessors[0]

assert len(successor_block.successors) == 0


# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
Expand Down