Skip to content

Commit 9c7b0c4

Browse files
committed
[MLIR] Add PatternRewriter::mergeBlockBefore() to merge a block in the middle of another block.
- This utility to merge a block anywhere into another one can help inline single block regions into other blocks. - Modified patterns test to use the new function. Differential Revision: https://reviews.llvm.org/D86251
1 parent 724f570 commit 9c7b0c4

File tree

5 files changed

+32
-9
lines changed

5 files changed

+32
-9
lines changed

mlir/include/mlir/IR/Block.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
207207
}
208208

209209
/// Return true if this block has no predecessors.
210-
bool hasNoPredecessors();
210+
bool hasNoPredecessors() { return pred_begin() == pred_end(); }
211+
212+
/// Returns true if this blocks has no successors.
213+
bool hasNoSuccessors() { return succ_begin() == succ_end(); }
211214

212215
/// If this block has exactly one predecessor, return it. Otherwise, return
213216
/// null.

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,11 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
326326
virtual void mergeBlocks(Block *source, Block *dest,
327327
ValueRange argValues = llvm::None);
328328

329+
// Merge the operations of block 'source' before the operation 'op'. Source
330+
// block should not have existing predecessors or successors.
331+
void mergeBlockBefore(Block *source, Operation *op,
332+
ValueRange argValues = llvm::None);
333+
329334
/// Split the operations starting at "before" (inclusive) out of the given
330335
/// block into a new block, and return it.
331336
virtual Block *splitBlock(Block *block, Block::iterator before);

mlir/lib/IR/Block.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,6 @@ Operation *Block::getTerminator() {
201201
return &back();
202202
}
203203

204-
/// Return true if this block has no predecessors.
205-
bool Block::hasNoPredecessors() { return pred_begin() == pred_end(); }
206-
207204
// Indexed successor access.
208205
unsigned Block::getNumSuccessors() {
209206
return empty() ? 0 : back().getNumSuccessors();

mlir/lib/IR/PatternMatch.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,28 @@ void PatternRewriter::mergeBlocks(Block *source, Block *dest,
128128
source->erase();
129129
}
130130

131+
// Merge the operations of block 'source' before the operation 'op'. Source
132+
// block should not have existing predecessors or successors.
133+
void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
134+
ValueRange argValues) {
135+
assert(source->hasNoPredecessors() &&
136+
"expected 'source' to have no predecessors");
137+
assert(source->hasNoSuccessors() &&
138+
"expected 'source' to have no successors");
139+
140+
// Split the block containing 'op' into two, one containg all operations
141+
// before 'op' (prologue) and another (epilogue) containing 'op' and all
142+
// operations after it.
143+
Block *prologue = op->getBlock();
144+
Block *epilogue = splitBlock(prologue, op->getIterator());
145+
146+
// Merge the source block at the end of the prologue.
147+
mergeBlocks(source, prologue, argValues);
148+
149+
// Merge the epilogue at the end the prologue.
150+
mergeBlocks(epilogue, prologue);
151+
}
152+
131153
/// Split the operations starting at "before" (inclusive) out of the given
132154
/// block into a new block, and return it.
133155
Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -893,16 +893,12 @@ struct TestMergeSingleBlockOps
893893
op.getParentOfType<SingleBlockImplicitTerminatorOp>();
894894
if (!parentOp)
895895
return failure();
896-
Block &parentBlock = parentOp.region().front();
897896
Block &innerBlock = op.region().front();
898897
TerminatorOp innerTerminator =
899898
cast<TerminatorOp>(innerBlock.getTerminator());
900-
Block *parentPrologue =
901-
rewriter.splitBlock(&parentBlock, Block::iterator(op));
899+
rewriter.mergeBlockBefore(&innerBlock, op);
902900
rewriter.eraseOp(innerTerminator);
903-
rewriter.mergeBlocks(&innerBlock, &parentBlock, {});
904901
rewriter.eraseOp(op);
905-
rewriter.mergeBlocks(parentPrologue, &parentBlock, {});
906902
rewriter.updateRootInPlace(op, [] {});
907903
return success();
908904
}

0 commit comments

Comments
 (0)