Skip to content

Commit c672b34

Browse files
[mlir][IR] Send missing notifications when inlining a block (#79593)
When a block is inlined into another block, the nested operations are moved into another block and the `notifyOperationInserted` callback should be triggered. This commit adds the missing notifications for: * `RewriterBase::inlineBlockBefore` * `RewriterBase::mergeBlocks`
1 parent ab87426 commit c672b34

File tree

4 files changed

+57
-8
lines changed

4 files changed

+57
-8
lines changed

mlir/lib/IR/PatternMatch.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,16 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
317317

318318
// Move operations from the source block to the dest block and erase the
319319
// source block.
320-
dest->getOperations().splice(before, source->getOperations());
320+
if (!listener) {
321+
// Fast path: If no listener is attached, move all operations at once.
322+
dest->getOperations().splice(before, source->getOperations());
323+
} else {
324+
while (!source->empty())
325+
moveOpBefore(&source->front(), dest, before);
326+
}
327+
328+
// Erase the source block.
329+
assert(source->empty() && "expected 'source' to be empty");
321330
eraseBlock(source);
322331
}
323332

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,6 @@ func.func @test_trivially_false_returning_two_results(%arg0: index) -> (index, i
411411
// CHECK: %[[c13:.*]] = arith.constant 13 : index
412412
%c7 = arith.constant 7 : index
413413
%c13 = arith.constant 13 : index
414-
// CHECK: %[[c2:.*]] = arith.constant 2 : index
415-
// CHECK: %[[c3:.*]] = arith.constant 3 : index
416414
%res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%c7, %c13) -> (index, index) {
417415
%c0 = arith.constant 0 : index
418416
%c1 = arith.constant 1 : index

mlir/test/Transforms/test-strict-pattern-driver.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,23 @@ func.func @test_move_op_before() {
249249
}) : () -> ()
250250
return
251251
}
252+
253+
// -----
254+
255+
// CHECK-AN: notifyOperationInserted: test.op_1, previous = test.op_2
256+
// CHECK-AN: notifyOperationInserted: test.op_2, previous = test.op_3
257+
// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
258+
// CHECK-AN-LABEL: func @test_inline_block_before(
259+
// CHECK-AN: test.op_1
260+
// CHECK-AN: test.op_2
261+
// CHECK-AN: test.op_3
262+
// CHECK-AN: test.inline_blocks_into_parent
263+
// CHECK-AN: return
264+
func.func @test_inline_block_before() {
265+
"test.inline_blocks_into_parent"() ({
266+
"test.op_1"() : () -> ()
267+
"test.op_2"() : () -> ()
268+
"test.op_3"() : () -> ()
269+
}) : () -> ()
270+
return
271+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,26 @@ struct MoveBeforeParentOp : public RewritePattern {
213213
}
214214
};
215215

216+
/// This pattern inlines blocks that are nested in
217+
/// "test.inline_blocks_into_parent" into the parent block.
218+
struct InlineBlocksIntoParent : public RewritePattern {
219+
InlineBlocksIntoParent(MLIRContext *context)
220+
: RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
221+
context) {}
222+
223+
LogicalResult matchAndRewrite(Operation *op,
224+
PatternRewriter &rewriter) const override {
225+
bool changed = false;
226+
for (Region &r : op->getRegions()) {
227+
while (!r.empty()) {
228+
rewriter.inlineBlockBefore(&r.front(), op);
229+
changed = true;
230+
}
231+
}
232+
return success(changed);
233+
}
234+
};
235+
216236
struct TestPatternDriver
217237
: public PassWrapper<TestPatternDriver, OperationPass<>> {
218238
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -292,20 +312,22 @@ struct TestStrictPatternDriver
292312
mlir::RewritePatternSet patterns(ctx);
293313
patterns.add<
294314
// clang-format off
295-
InsertSameOp,
296-
ReplaceWithNewOp,
297-
EraseOp,
298315
ChangeBlockOp,
316+
EraseOp,
299317
ImplicitChangeOp,
300-
MoveBeforeParentOp
318+
InlineBlocksIntoParent,
319+
InsertSameOp,
320+
MoveBeforeParentOp,
321+
ReplaceWithNewOp
301322
// clang-format on
302323
>(ctx);
303324
SmallVector<Operation *> ops;
304325
getOperation()->walk([&](Operation *op) {
305326
StringRef opName = op->getName().getStringRef();
306327
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
307328
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
308-
opName == "test.move_before_parent_op") {
329+
opName == "test.move_before_parent_op" ||
330+
opName == "test.inline_blocks_into_parent") {
309331
ops.push_back(op);
310332
}
311333
});

0 commit comments

Comments
 (0)