-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms] Support moveOpBefore
/After
in dialect conversion
#81240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Support moveOpBefore
/After
in dialect conversion
#81240
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd a new rewrite action for "operation movements". This action can roll back
Full diff: https://github.com/llvm/llvm-project/pull/81240.diff 5 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 78dcfe7f6fc3d..b8aeea0d23475 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {
/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
- virtual void moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator);
+ void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
@@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {
/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
- virtual void moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator);
+ void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this block and insert it right before `existingBlock`.
void moveBlockBefore(Block *block, Block *anotherBlock);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f061d761ecefb..c0c702a7d3482 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -738,11 +738,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
- void moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator) override;
- void moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator) override;
-
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 44c107c8733f3..ffdb069f6e9b8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -757,7 +757,8 @@ class RewriteAction {
InlineBlock,
MoveBlock,
SplitBlock,
- BlockTypeConversion
+ BlockTypeConversion,
+ MoveOperation
};
virtual ~RewriteAction() = default;
@@ -970,6 +971,54 @@ class BlockTypeConversionAction : public BlockAction {
void rollback() override;
};
+
+/// An operation rewrite.
+class OperationAction : public RewriteAction {
+public:
+ /// Return the operation that this action operates on.
+ Operation *getOperation() const { return op; }
+
+ static bool classof(const RewriteAction *action) {
+ return action->getKind() >= Kind::MoveOperation &&
+ action->getKind() <= Kind::MoveOperation;
+ }
+
+protected:
+ OperationAction(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : RewriteAction(kind, rewriterImpl), op(op) {}
+
+ // The operation that this action operates on.
+ Operation *op;
+};
+
+/// Rewrite action that represent the moving of a block.
+class MoveOperationAction : public OperationAction {
+public:
+ MoveOperationAction(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op, Block *block, Operation *insertBeforeOp)
+ : OperationAction(Kind::MoveOperation, rewriterImpl, op), block(block),
+ insertBeforeOp(insertBeforeOp) {}
+
+ static bool classof(const RewriteAction *action) {
+ return action->getKind() == Kind::MoveOperation;
+ }
+
+ void rollback() override {
+ // Move the operation back to its original position.
+ Block::iterator before =
+ insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
+ block->getOperations().splice(before, op->getBlock()->getOperations(), op);
+ }
+
+private:
+ // The block in which this operation was previously contained.
+ Block *block;
+
+ // The original successor of this operation before it was moved. "nullptr" if
+ // this operation was the only operation in the region.
+ Operation *insertBeforeOp;
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -1468,12 +1517,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
- assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
- createdOps.push_back(op);
+ if (!previous.isSet()) {
+ // This is a newly created op.
+ createdOps.push_back(op);
+ return;
+ }
+ Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
+ ? nullptr
+ : &*previous.getPoint();
+ appendRewriteAction<MoveOperationAction>(op, previous.getBlock(), prevOp);
}
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
@@ -1712,18 +1768,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}
-void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator) {
- llvm_unreachable(
- "moving single ops is not supported in a dialect conversion");
-}
-
-void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator) {
- llvm_unreachable(
- "moving single ops is not supported in a dialect conversion");
-}
-
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d8cf6e4719ced..84fcc18ab7d37 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -320,3 +320,17 @@ module {
return
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_move_op_before_rollback()
+func.func @test_move_op_before_rollback() {
+ // CHECK: "test.one_region_op"()
+ // CHECK: "test.hoist_me"()
+ "test.one_region_op"() ({
+ // expected-remark @below{{'test.hoist_me' is not legalizable}}
+ %0 = "test.hoist_me"() : () -> (i32)
+ "test.valid"(%0) : (i32) -> ()
+ }) : () -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d7e5d6db50c1f..1c02232b8adbb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
}
};
+/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
+/// This is to test the rollback logic.
+struct TestUndoMoveOpBefore : public ConversionPattern {
+ TestUndoMoveOpBefore(MLIRContext *ctx)
+ : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.moveOpBefore(op, op->getParentOp());
+ // Replace with an illegal op to ensure the conversion fails.
+ rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
+ return success();
+ }
+};
+
/// A rewrite pattern that tests the undo mechanism when erasing a block.
struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase(MLIRContext *ctx)
@@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
- TestCreateUnregisteredOp>(&getContext());
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
- TerminatorOp>();
+ TerminatorOp, OneRegionOp>();
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
|
@@ -970,6 +971,54 @@ class BlockTypeConversionAction : public BlockAction { | |||
|
|||
void rollback() override; | |||
}; | |||
|
|||
/// An operation rewrite. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you expand on the role of the class, the context where it's used?
The "Action" name for this whole section is not great by the way, since the concept of "Actions" is now core to MLIR tracing...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed the class to OperationRewrite
. I added more documentation to the superclass IRRewrite
in the dependent PR.
7503c0c
to
ebfaca6
Compare
ae9dcbb
to
c60c43b
Compare
c60c43b
to
25fe429
Compare
ebfaca6
to
1d17c76
Compare
Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`. `RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` is no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.) BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1d17c76
to
5e261de
Compare
Add a new rewrite class for "operation movements". This rewrite class can roll back
moveOpBefore
andmoveOpAfter
.RewriterBase::moveOpBefore
andRewriterBase::moveOpAfter
is no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)