Skip to content

Commit 1d17c76

Browse files
[mlir][Transforms] Support moveOpBefore/After in dialect conversion
Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`. `RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` is no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.) BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1 parent 25fe429 commit 1d17c76

File tree

5 files changed

+95
-28
lines changed

5 files changed

+95
-28
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {
588588

589589
/// Unlink this operation from its current block and insert it right before
590590
/// `iterator` in the specified block.
591-
virtual void moveOpBefore(Operation *op, Block *block,
592-
Block::iterator iterator);
591+
void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
593592

594593
/// Unlink this operation from its current block and insert it right after
595594
/// `existingOp` which may be in the same or another block in the same
@@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {
598597

599598
/// Unlink this operation from its current block and insert it right after
600599
/// `iterator` in the specified block.
601-
virtual void moveOpAfter(Operation *op, Block *block,
602-
Block::iterator iterator);
600+
void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
603601

604602
/// Unlink this block and insert it right before `existingBlock`.
605603
void moveBlockBefore(Block *block, Block *anotherBlock);

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
744744

745745
/// PatternRewriter hook for updating the given operation in-place.
746746
/// Note: These methods only track updates to the given operation itself,
747-
/// and not nested regions. Updates to regions will still require notification
748-
/// through other more specific hooks above.
747+
/// and not nested regions. Updates to regions will still require
748+
/// notification through other more specific hooks above.
749749
void startOpModification(Operation *op) override;
750750

751751
/// PatternRewriter hook for updating the given operation in-place.
@@ -761,11 +761,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
761761
// Hide unsupported pattern rewriter API.
762762
using OpBuilder::setListener;
763763

764-
void moveOpBefore(Operation *op, Block *block,
765-
Block::iterator iterator) override;
766-
void moveOpAfter(Operation *op, Block *block,
767-
Block::iterator iterator) override;
768-
769764
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
770765
};
771766

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,8 @@ class IRRewrite {
760760
InlineBlock,
761761
MoveBlock,
762762
SplitBlock,
763-
BlockTypeConversion
763+
BlockTypeConversion,
764+
MoveOperation
764765
};
765766

766767
virtual ~IRRewrite() = default;
@@ -982,6 +983,54 @@ class BlockTypeConversionRewrite : public BlockRewrite {
982983
// `ArgConverter::applyRewrites`. This should be done in the "commit" method.
983984
void rollback() override;
984985
};
986+
987+
/// An operation rewrite.
988+
class OperationRewrite : public IRRewrite {
989+
public:
990+
/// Return the operation that this rewrite operates on.
991+
Operation *getOperation() const { return op; }
992+
993+
static bool classof(const IRRewrite *rewrite) {
994+
return rewrite->getKind() >= Kind::MoveOperation &&
995+
rewrite->getKind() <= Kind::MoveOperation;
996+
}
997+
998+
protected:
999+
OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
1000+
Operation *op)
1001+
: IRRewrite(kind, rewriterImpl), op(op) {}
1002+
1003+
// The operation that this rewrite operates on.
1004+
Operation *op;
1005+
};
1006+
1007+
/// Moving of an operation. This rewrite is immediately reflected in the IR.
1008+
class MoveOperationRewrite : public OperationRewrite {
1009+
public:
1010+
MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
1011+
Operation *op, Block *block, Operation *insertBeforeOp)
1012+
: OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
1013+
insertBeforeOp(insertBeforeOp) {}
1014+
1015+
static bool classof(const IRRewrite *rewrite) {
1016+
return rewrite->getKind() == Kind::MoveOperation;
1017+
}
1018+
1019+
void rollback() override {
1020+
// Move the operation back to its original position.
1021+
Block::iterator before =
1022+
insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
1023+
block->getOperations().splice(before, op->getBlock()->getOperations(), op);
1024+
}
1025+
1026+
private:
1027+
// The block in which this operation was previously contained.
1028+
Block *block;
1029+
1030+
// The original successor of this operation before it was moved. "nullptr" if
1031+
// this operation was the only operation in the region.
1032+
Operation *insertBeforeOp;
1033+
};
9851034
} // namespace
9861035

9871036
//===----------------------------------------------------------------------===//
@@ -1478,12 +1527,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
14781527

14791528
void ConversionPatternRewriterImpl::notifyOperationInserted(
14801529
Operation *op, OpBuilder::InsertPoint previous) {
1481-
assert(!previous.isSet() && "expected newly created op");
14821530
LLVM_DEBUG({
14831531
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
14841532
<< ")\n";
14851533
});
1486-
createdOps.push_back(op);
1534+
if (!previous.isSet()) {
1535+
// This is a newly created op.
1536+
createdOps.push_back(op);
1537+
return;
1538+
}
1539+
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1540+
? nullptr
1541+
: &*previous.getPoint();
1542+
appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
14871543
}
14881544

14891545
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
@@ -1722,18 +1778,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
17221778
rootUpdates.erase(rootUpdates.begin() + updateIdx);
17231779
}
17241780

1725-
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
1726-
Block::iterator iterator) {
1727-
llvm_unreachable(
1728-
"moving single ops is not supported in a dialect conversion");
1729-
}
1730-
1731-
void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
1732-
Block::iterator iterator) {
1733-
llvm_unreachable(
1734-
"moving single ops is not supported in a dialect conversion");
1735-
}
1736-
17371781
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
17381782
return *impl;
17391783
}

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,17 @@ module {
320320
return
321321
}
322322
}
323+
324+
// -----
325+
326+
// CHECK-LABEL: func @test_move_op_before_rollback()
327+
func.func @test_move_op_before_rollback() {
328+
// CHECK: "test.one_region_op"()
329+
// CHECK: "test.hoist_me"()
330+
"test.one_region_op"() ({
331+
// expected-remark @below{{'test.hoist_me' is not legalizable}}
332+
%0 = "test.hoist_me"() : () -> (i32)
333+
"test.valid"(%0) : (i32) -> ()
334+
}) : () -> ()
335+
"test.return"() : () -> ()
336+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
773773
}
774774
};
775775

776+
/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
777+
/// This is to test the rollback logic.
778+
struct TestUndoMoveOpBefore : public ConversionPattern {
779+
TestUndoMoveOpBefore(MLIRContext *ctx)
780+
: ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
781+
782+
LogicalResult
783+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
784+
ConversionPatternRewriter &rewriter) const override {
785+
rewriter.moveOpBefore(op, op->getParentOp());
786+
// Replace with an illegal op to ensure the conversion fails.
787+
rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
788+
return success();
789+
}
790+
};
791+
776792
/// A rewrite pattern that tests the undo mechanism when erasing a block.
777793
struct TestUndoBlockErase : public ConversionPattern {
778794
TestUndoBlockErase(MLIRContext *ctx)
@@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
10691085
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
10701086
TestNonRootReplacement, TestBoundedRecursiveRewrite,
10711087
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1072-
TestCreateUnregisteredOp>(&getContext());
1088+
TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
10731089
patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
10741090
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
10751091
converter);
@@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
10791095
ConversionTarget target(getContext());
10801096
target.addLegalOp<ModuleOp>();
10811097
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1082-
TerminatorOp>();
1098+
TerminatorOp, OneRegionOp>();
10831099
target
10841100
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
10851101
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {

0 commit comments

Comments
 (0)