Skip to content

[mlir][Transforms] Support moveOpBefore/After in dialect conversion #81240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Feb 9, 2024

Add a new rewrite class for "operation movements". This rewrite class can roll back moveOpBefore and moveOpAfter.

RewriterBase::moveOpBefore and RewriterBase::moveOpAfter is no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Feb 9, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 9, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add a new rewrite action for "operation movements". This action can roll back moveOpBefore and moveOpAfter.

RewriterBase::moveOpBefore and RewriterBase::moveOpAfter is no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)


Full diff: https://github.com/llvm/llvm-project/pull/81240.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+2-4)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (-5)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+59-15)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+14)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+18-2)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 78dcfe7f6fc3d..b8aeea0d23475 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {
 
   /// Unlink this operation from its current block and insert it right before
   /// `iterator` in the specified block.
-  virtual void moveOpBefore(Operation *op, Block *block,
-                            Block::iterator iterator);
+  void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this operation from its current block and insert it right after
   /// `existingOp` which may be in the same or another block in the same
@@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {
 
   /// Unlink this operation from its current block and insert it right after
   /// `iterator` in the specified block.
-  virtual void moveOpAfter(Operation *op, Block *block,
-                           Block::iterator iterator);
+  void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this block and insert it right before `existingBlock`.
   void moveBlockBefore(Block *block, Block *anotherBlock);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f061d761ecefb..c0c702a7d3482 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -738,11 +738,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
 
-  void moveOpBefore(Operation *op, Block *block,
-                    Block::iterator iterator) override;
-  void moveOpAfter(Operation *op, Block *block,
-                   Block::iterator iterator) override;
-
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 44c107c8733f3..ffdb069f6e9b8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -757,7 +757,8 @@ class RewriteAction {
     InlineBlock,
     MoveBlock,
     SplitBlock,
-    BlockTypeConversion
+    BlockTypeConversion,
+    MoveOperation
   };
 
   virtual ~RewriteAction() = default;
@@ -970,6 +971,54 @@ class BlockTypeConversionAction : public BlockAction {
 
   void rollback() override;
 };
+
+/// An operation rewrite.
+class OperationAction : public RewriteAction {
+public:
+  /// Return the operation that this action operates on.
+  Operation *getOperation() const { return op; }
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() >= Kind::MoveOperation &&
+           action->getKind() <= Kind::MoveOperation;
+  }
+
+protected:
+  OperationAction(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+                  Operation *op)
+      : RewriteAction(kind, rewriterImpl), op(op) {}
+
+  // The operation that this action operates on.
+  Operation *op;
+};
+
+/// Rewrite action that represent the moving of a block.
+class MoveOperationAction : public OperationAction {
+public:
+  MoveOperationAction(ConversionPatternRewriterImpl &rewriterImpl,
+                      Operation *op, Block *block, Operation *insertBeforeOp)
+      : OperationAction(Kind::MoveOperation, rewriterImpl, op), block(block),
+        insertBeforeOp(insertBeforeOp) {}
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::MoveOperation;
+  }
+
+  void rollback() override {
+    // Move the operation back to its original position.
+    Block::iterator before =
+        insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
+    block->getOperations().splice(before, op->getBlock()->getOperations(), op);
+  }
+
+private:
+  // The block in which this operation was previously contained.
+  Block *block;
+
+  // The original successor of this operation before it was moved. "nullptr" if
+  // this operation was the only operation in the region.
+  Operation *insertBeforeOp;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1468,12 +1517,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 
 void ConversionPatternRewriterImpl::notifyOperationInserted(
     Operation *op, OpBuilder::InsertPoint previous) {
-  assert(!previous.isSet() && "expected newly created op");
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  createdOps.push_back(op);
+  if (!previous.isSet()) {
+    // This is a newly created op.
+    createdOps.push_back(op);
+    return;
+  }
+  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
+                          ? nullptr
+                          : &*previous.getPoint();
+  appendRewriteAction<MoveOperationAction>(op, previous.getBlock(), prevOp);
 }
 
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
@@ -1712,18 +1768,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + updateIdx);
 }
 
-void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
-                                             Block::iterator iterator) {
-  llvm_unreachable(
-      "moving single ops is not supported in a dialect conversion");
-}
-
-void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
-                                            Block::iterator iterator) {
-  llvm_unreachable(
-      "moving single ops is not supported in a dialect conversion");
-}
-
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d8cf6e4719ced..84fcc18ab7d37 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -320,3 +320,17 @@ module {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_move_op_before_rollback()
+func.func @test_move_op_before_rollback() {
+  // CHECK: "test.one_region_op"()
+  // CHECK: "test.hoist_me"()
+  "test.one_region_op"() ({
+    // expected-remark @below{{'test.hoist_me' is not legalizable}}
+    %0 = "test.hoist_me"() : () -> (i32)
+    "test.valid"(%0) : (i32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d7e5d6db50c1f..1c02232b8adbb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
   }
 };
 
+/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
+/// This is to test the rollback logic.
+struct TestUndoMoveOpBefore : public ConversionPattern {
+  TestUndoMoveOpBefore(MLIRContext *ctx)
+      : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.moveOpBefore(op, op->getParentOp());
+    // Replace with an illegal op to ensure the conversion fails.
+    rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
+    return success();
+  }
+};
+
 /// A rewrite pattern that tests the undo mechanism when erasing a block.
 struct TestUndoBlockErase : public ConversionPattern {
   TestUndoBlockErase(MLIRContext *ctx)
@@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
              TestNonRootReplacement, TestBoundedRecursiveRewrite,
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp>(&getContext());
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
-                      TerminatorOp>();
+                      TerminatorOp, OneRegionOp>();
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {

@@ -970,6 +971,54 @@ class BlockTypeConversionAction : public BlockAction {

void rollback() override;
};

/// An operation rewrite.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand on the role of the class, the context where it's used?

The "Action" name for this whole section is not great by the way, since the concept of "Actions" is now core to MLIR tracing...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the class to OperationRewrite. I added more documentation to the superclass IRRewrite in the dependent PR.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_move_op_before branch from 7503c0c to ebfaca6 Compare February 12, 2024 09:09
@matthias-springer matthias-springer force-pushed the users/matthias-springer/rewrite_action branch from ae9dcbb to c60c43b Compare February 12, 2024 09:10
@matthias-springer matthias-springer force-pushed the users/matthias-springer/rewrite_action branch from c60c43b to 25fe429 Compare February 14, 2024 16:07
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_move_op_before branch from ebfaca6 to 1d17c76 Compare February 14, 2024 16:12
Base automatically changed from users/matthias-springer/rewrite_action to main February 14, 2024 16:15
Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`.

`RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` is no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_move_op_before branch from 1d17c76 to 5e261de Compare February 14, 2024 16:33
@matthias-springer matthias-springer merged commit 8f4cd2c into main Feb 14, 2024
@matthias-springer matthias-springer deleted the users/matthias-springer/dialect_conversion_move_op_before branch February 14, 2024 16:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants