Skip to content

[mlir][IR] Make replaceOp / replaceAllUsesWith API consistent #82629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2024

Conversation

matthias-springer
Copy link
Member

  • replaceOp replaces all uses of the original op and erases the old op.
  • replaceAllUsesWith replaces all uses of the original op/value/block. It does not erase any IR.

This commit renames replaceOpWithIf to replaceUsesWithIf. replaceOpWithIf was a misnomer because the function never erases the original op. Similarly, replaceOpWithinBlock is renamed to replaceUsesWithinBlock. (No "operation replaced" is sent because the op is not erased.)

Also improve comments.

@llvmbot
Copy link
Member

llvmbot commented Feb 22, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes
  • replaceOp replaces all uses of the original op and erases the old op.
  • replaceAllUsesWith replaces all uses of the original op/value/block. It does not erase any IR.

This commit renames replaceOpWithIf to replaceUsesWithIf. replaceOpWithIf was a misnomer because the function never erases the original op. Similarly, replaceOpWithinBlock is renamed to replaceUsesWithinBlock. (No "operation replaced" is sent because the op is not erased.)

Also improve comments.


Full diff: https://github.com/llvm/llvm-project/pull/82629.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+42-45)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (-6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+1-1)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+27-46)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+1-14)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+1-1)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+1-1)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ce3bc3fc2e783..909bd9e99f1481 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -497,42 +497,19 @@ class RewriterBase : public OpBuilder {
                           Region::iterator before);
   void inlineRegionBefore(Region &region, Block *before);
 
-  /// This method replaces the uses of the results of `op` with the values in
-  /// `newValues` when the provided `functor` returns true for a specific use.
-  /// The number of values in `newValues` is required to match the number of
-  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
-  /// the uses of `op` were replaced. Note that in some rewriters, the given
-  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
-  /// As such, the function should not capture by reference and instead use
-  /// value capture as necessary.
-  virtual void
-  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
-                  llvm::unique_function<bool(OpOperand &) const> functor);
-  void replaceOpWithIf(Operation *op, ValueRange newValues,
-                       llvm::unique_function<bool(OpOperand &) const> functor) {
-    replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
-                    std::move(functor));
-  }
-
-  /// This method replaces the uses of the results of `op` with the values in
-  /// `newValues` when a use is nested within the given `block`. The number of
-  /// values in `newValues` is required to match the number of results of `op`.
-  /// If all uses of this operation are replaced, the operation is erased.
-  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
-                            bool *allUsesReplaced = nullptr);
-
-  /// This method replaces the results of the operation with the specified list
-  /// of values. The number of provided values must match the number of results
-  /// of the operation. The replaced op is erased.
+  /// Replace the results of the given (original) operation with the specified
+  /// list of values (replacements). The result types of the given op and the
+  /// replacements must match. The original op is erased.
   virtual void replaceOp(Operation *op, ValueRange newValues);
 
-  /// This method replaces the results of the operation with the specified
-  /// new op (replacement). The number of results of the two operations must
-  /// match. The replaced op is erased.
+  /// Replace the results of the given (original) operation with the specified
+  /// new op (replacement). The result types of the two ops must match. The
+  /// original op is erased.
   virtual void replaceOp(Operation *op, Operation *newOp);
 
-  /// Replaces the result op with a new op that is created without verification.
-  /// The result values of the two ops must be the same types.
+  /// Replace the results of the given (original) op with a new op that is
+  /// created without verification (replacement). The result values of the two
+  /// ops must match. The original op is erased.
   template <typename OpTy, typename... Args>
   OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
     auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
@@ -634,9 +611,8 @@ class RewriterBase : public OpBuilder {
     finalizeOpModification(root);
   }
 
-  /// Find uses of `from` and replace them with `to`. It also marks every
-  /// modified uses and notifies the rewriter that an in-place operation
-  /// modification is about to happen.
+  /// Find uses of `from` and replace them with `to`. Also notify the listener
+  /// about every in-place op modification (for every use that was replaced).
   void replaceAllUsesWith(Value from, Value to) {
     return replaceAllUsesWith(from.getImpl(), to);
   }
@@ -652,22 +628,43 @@ class RewriterBase : public OpBuilder {
     for (auto it : llvm::zip(from, to))
       replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
   }
+  void replaceAllUsesWith(Operation *from, ValueRange to) {
+    replaceAllUsesWith(from->getResults(), to);
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
-  /// true. It also marks every modified uses and notifies the rewriter that an
-  /// in-place operation modification is about to happen.
+  /// true. Also notify the listener about every in-place op modification (for
+  /// every use that was replaced). The optional `allUsesReplaced` flag is set
+  /// to "true" if all uses were replaced.
   void replaceUsesWithIf(Value from, Value to,
-                         function_ref<bool(OpOperand &)> functor);
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr);
   void replaceUsesWithIf(ValueRange from, ValueRange to,
-                         function_ref<bool(OpOperand &)> functor) {
-    assert(from.size() == to.size() && "incorrect number of replacements");
-    for (auto it : llvm::zip(from, to))
-      replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr);
+  void replaceUsesWithIf(Operation *from, ValueRange to,
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr) {
+    replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
+  }
+
+  /// Find uses of `from` within `block` and replace them with `to`. Also notify
+  /// the listener about every in-place op modification (for every use that was
+  /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
+  /// uses were replaced.
+  void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block,
+                              bool *allUsesReplaced = nullptr) {
+    replaceUsesWithIf(
+        op, newValues,
+        [block](OpOperand &use) {
+          return block->getParentOp()->isProperAncestor(use.getOwner());
+        },
+        allUsesReplaced);
   }
 
   /// Find uses of `from` and replace them with `to` except if the user is
-  /// `exceptedUser`. It also marks every modified uses and notifies the
-  /// rewriter that an in-place operation modification is about to happen.
+  /// `exceptedUser`. Also notify the listener about every in-place op
+  /// modification (for every use that was replaced).
   void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
     return replaceUsesWithIf(from, to, [&](OpOperand &use) {
       Operation *user = use.getOwner();
@@ -675,7 +672,7 @@ class RewriterBase : public OpBuilder {
     });
   }
 
-  /// Used to notify the rewriter that the IR failed to be rewritten because of
+  /// Used to notify the listener that the IR failed to be rewritten because of
   /// a match failure, and provide a callback to populate a diagnostic with the
   /// reason why the failure occurred. This method allows for derived rewriters
   /// to optionally hook into the reason why a rewrite failed, and display it to
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 2575be4cdea1ac..cde72ba36196c6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -719,12 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
-  /// PatternRewriter hook for replacing an operation when the given functor
-  /// returns "true".
-  void replaceOpWithIf(
-      Operation *op, ValueRange newValues, bool *allUsesReplaced,
-      llvm::unique_function<bool(OpOperand &) const> functor) override;
-
   /// PatternRewriter hook for replacing an operation.
   void replaceOp(Operation *op, ValueRange newValues) override;
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 5cd6d4597affaf..1658ea67a46077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
       scalarReplacements.push_back(
           residualGenericOpBody->getArgument(num + origNumInputs));
     bool allUsesReplaced = false;
-    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
-                                  residualGenericOpBody, &allUsesReplaced);
+    rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements,
+                                    residualGenericOpBody, &allUsesReplaced);
     assert(!allUsesReplaced &&
            "peeled scalar operation is erased when it wasnt expected to be");
   }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5d220c6cdd7e58..ffe972e5842e29 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -871,7 +871,7 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
         {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
     Value materialized =
         getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
-    b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
+    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
       return use.getOwner() != materialized.getDefiningOp();
     });
   }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5ba5328f14b89e..1f8e095bb7a8c3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,41 +110,6 @@ RewriterBase::~RewriterBase() {
   // Out of line to provide a vtable anchor for the class.
 }
 
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when the provided `functor` returns true for a specific use.
-/// The number of values in `newValues` is required to match the number of
-/// results of `op`.
-void RewriterBase::replaceOpWithIf(
-    Operation *op, ValueRange newValues, bool *allUsesReplaced,
-    llvm::unique_function<bool(OpOperand &) const> functor) {
-  assert(op->getNumResults() == newValues.size() &&
-         "incorrect number of values to replace operation");
-
-  // Notify the listener that we're about to replace this op.
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationReplaced(op, newValues);
-
-  // Replace each use of the results when the functor is true.
-  bool replacedAllUses = true;
-  for (auto it : llvm::zip(op->getResults(), newValues)) {
-    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
-    replacedAllUses &= std::get<0>(it).use_empty();
-  }
-  if (allUsesReplaced)
-    *allUsesReplaced = replacedAllUses;
-}
-
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when a use is nested within the given `block`. The number of
-/// values in `newValues` is required to match the number of results of `op`.
-/// If all uses of this operation are replaced, the operation is erased.
-void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
-                                        Block *block, bool *allUsesReplaced) {
-  replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
-    return block->getParentOp()->isProperAncestor(use.getOwner());
-  });
-}
-
 /// This method replaces the results of the operation with the specified list of
 /// values. The number of provided values must match the number of results of
 /// the operation. The replaced op is erased.
@@ -156,9 +121,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newValues);
 
-  // Replace results one-by-one. Also notifies the listener of modifications.
-  for (auto it : llvm::zip(op->getResults(), newValues))
-    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  // Replace all result uses. Also notifies the listener of modifications.
+  replaceAllUsesWith(op, newValues);
 
   // Erase op and notify listener.
   eraseOp(op);
@@ -176,9 +140,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newOp);
 
-  // Replace results one-by-one. Also notifies the listener of modifications.
-  for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
-    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  // Replace all result uses. Also notifies the listener of modifications.
+  replaceAllUsesWith(op, newOp->getResults());
 
   // Erase op and notify listener.
   eraseOp(op);
@@ -276,15 +239,33 @@ void RewriterBase::finalizeOpModification(Operation *op) {
     rewriteListener->notifyOperationModified(op);
 }
 
-/// Find uses of `from` and replace them with `to` if the `functor` returns
-/// true. It also marks every modified uses and notifies the rewriter that an
-/// in-place operation modification is about to happen.
 void RewriterBase::replaceUsesWithIf(Value from, Value to,
-                                     function_ref<bool(OpOperand &)> functor) {
+                                     function_ref<bool(OpOperand &)> functor,
+                                     bool *allUsesReplaced) {
+  bool allReplaced = true;
   for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
-    if (functor(operand))
+    bool replace = functor(operand);
+    if (replace)
       modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
+    allReplaced &= replace;
   }
+  if (allUsesReplaced)
+    *allUsesReplaced = allReplaced;
+}
+
+void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
+                                     function_ref<bool(OpOperand &)> functor,
+                                     bool *allUsesReplaced) {
+  assert(from.size() == to.size() && "incorrect number of replacements");
+  bool allReplaced = true;
+  for (auto it : llvm::zip_equal(from, to)) {
+    bool r;
+    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
+                      /*allUsesReplaced=*/&r);
+    allReplaced &= r;
+  }
+  if (allUsesReplaced)
+    *allUsesReplaced = allReplaced;
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index db41b9f19e7e8d..239d41d2fa2f41 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1173,7 +1173,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     detachNestedAndErase(op);
   }
 
-  // Pop all of the newly created operations.
+  // Pop all of the newly created operations.Patt
   while (createdOps.size() != state.numCreatedOps) {
     detachNestedAndErase(createdOps.back());
     createdOps.pop_back();
@@ -1538,19 +1538,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
-void ConversionPatternRewriter::replaceOpWithIf(
-    Operation *op, ValueRange newValues, bool *allUsesReplaced,
-    llvm::unique_function<bool(OpOperand &) const> functor) {
-  // TODO: To support this we will need to rework a bit of how replacements are
-  // tracked, given that this isn't guranteed to replace all of the uses of an
-  // operation. The main change is that now an operation can be replaced
-  // multiple times, in parts. The current "set" based tracking is mainly useful
-  // for tracking if a replaced operation should be ignored, i.e. if all of the
-  // uses will be replaced.
-  llvm_unreachable(
-      "replaceOpWithIf is currently not supported by DialectConversion");
-}
-
 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
   assert(op && newOp && "expected non-null op");
   replaceOp(op, newOp->getResults());
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e8b07143fc60bd..eff8acdfb33d20 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -161,7 +161,7 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
   rewriter.setInsertionPointToStart(newEntryBlock);
   for (auto *clonedOp : clonedOperations) {
     Operation *newOp = rewriter.clone(*clonedOp, map);
-    rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
+    rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
   }
   rewriter.mergeBlocks(
       entryBlock, newEntryBlock,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 108cfe8950ef67..0036225ca1b27f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1831,7 +1831,7 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
     OperandRange operands = op.getOperands();
 
     // Replace non-terminator uses with the first operand.
-    rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
+    rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
     });
     // Replace everything else with the second operand if the operation isn't

@llvmbot
Copy link
Member

llvmbot commented Feb 22, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes
  • replaceOp replaces all uses of the original op and erases the old op.
  • replaceAllUsesWith replaces all uses of the original op/value/block. It does not erase any IR.

This commit renames replaceOpWithIf to replaceUsesWithIf. replaceOpWithIf was a misnomer because the function never erases the original op. Similarly, replaceOpWithinBlock is renamed to replaceUsesWithinBlock. (No "operation replaced" is sent because the op is not erased.)

Also improve comments.


Full diff: https://github.com/llvm/llvm-project/pull/82629.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+42-45)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (-6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+1-1)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+27-46)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+1-14)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+1-1)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+1-1)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ce3bc3fc2e783..909bd9e99f1481 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -497,42 +497,19 @@ class RewriterBase : public OpBuilder {
                           Region::iterator before);
   void inlineRegionBefore(Region &region, Block *before);
 
-  /// This method replaces the uses of the results of `op` with the values in
-  /// `newValues` when the provided `functor` returns true for a specific use.
-  /// The number of values in `newValues` is required to match the number of
-  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
-  /// the uses of `op` were replaced. Note that in some rewriters, the given
-  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
-  /// As such, the function should not capture by reference and instead use
-  /// value capture as necessary.
-  virtual void
-  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
-                  llvm::unique_function<bool(OpOperand &) const> functor);
-  void replaceOpWithIf(Operation *op, ValueRange newValues,
-                       llvm::unique_function<bool(OpOperand &) const> functor) {
-    replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
-                    std::move(functor));
-  }
-
-  /// This method replaces the uses of the results of `op` with the values in
-  /// `newValues` when a use is nested within the given `block`. The number of
-  /// values in `newValues` is required to match the number of results of `op`.
-  /// If all uses of this operation are replaced, the operation is erased.
-  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
-                            bool *allUsesReplaced = nullptr);
-
-  /// This method replaces the results of the operation with the specified list
-  /// of values. The number of provided values must match the number of results
-  /// of the operation. The replaced op is erased.
+  /// Replace the results of the given (original) operation with the specified
+  /// list of values (replacements). The result types of the given op and the
+  /// replacements must match. The original op is erased.
   virtual void replaceOp(Operation *op, ValueRange newValues);
 
-  /// This method replaces the results of the operation with the specified
-  /// new op (replacement). The number of results of the two operations must
-  /// match. The replaced op is erased.
+  /// Replace the results of the given (original) operation with the specified
+  /// new op (replacement). The result types of the two ops must match. The
+  /// original op is erased.
   virtual void replaceOp(Operation *op, Operation *newOp);
 
-  /// Replaces the result op with a new op that is created without verification.
-  /// The result values of the two ops must be the same types.
+  /// Replace the results of the given (original) op with a new op that is
+  /// created without verification (replacement). The result values of the two
+  /// ops must match. The original op is erased.
   template <typename OpTy, typename... Args>
   OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
     auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
@@ -634,9 +611,8 @@ class RewriterBase : public OpBuilder {
     finalizeOpModification(root);
   }
 
-  /// Find uses of `from` and replace them with `to`. It also marks every
-  /// modified uses and notifies the rewriter that an in-place operation
-  /// modification is about to happen.
+  /// Find uses of `from` and replace them with `to`. Also notify the listener
+  /// about every in-place op modification (for every use that was replaced).
   void replaceAllUsesWith(Value from, Value to) {
     return replaceAllUsesWith(from.getImpl(), to);
   }
@@ -652,22 +628,43 @@ class RewriterBase : public OpBuilder {
     for (auto it : llvm::zip(from, to))
       replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
   }
+  void replaceAllUsesWith(Operation *from, ValueRange to) {
+    replaceAllUsesWith(from->getResults(), to);
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
-  /// true. It also marks every modified uses and notifies the rewriter that an
-  /// in-place operation modification is about to happen.
+  /// true. Also notify the listener about every in-place op modification (for
+  /// every use that was replaced). The optional `allUsesReplaced` flag is set
+  /// to "true" if all uses were replaced.
   void replaceUsesWithIf(Value from, Value to,
-                         function_ref<bool(OpOperand &)> functor);
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr);
   void replaceUsesWithIf(ValueRange from, ValueRange to,
-                         function_ref<bool(OpOperand &)> functor) {
-    assert(from.size() == to.size() && "incorrect number of replacements");
-    for (auto it : llvm::zip(from, to))
-      replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr);
+  void replaceUsesWithIf(Operation *from, ValueRange to,
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr) {
+    replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
+  }
+
+  /// Find uses of `from` within `block` and replace them with `to`. Also notify
+  /// the listener about every in-place op modification (for every use that was
+  /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
+  /// uses were replaced.
+  void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block,
+                              bool *allUsesReplaced = nullptr) {
+    replaceUsesWithIf(
+        op, newValues,
+        [block](OpOperand &use) {
+          return block->getParentOp()->isProperAncestor(use.getOwner());
+        },
+        allUsesReplaced);
   }
 
   /// Find uses of `from` and replace them with `to` except if the user is
-  /// `exceptedUser`. It also marks every modified uses and notifies the
-  /// rewriter that an in-place operation modification is about to happen.
+  /// `exceptedUser`. Also notify the listener about every in-place op
+  /// modification (for every use that was replaced).
   void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
     return replaceUsesWithIf(from, to, [&](OpOperand &use) {
       Operation *user = use.getOwner();
@@ -675,7 +672,7 @@ class RewriterBase : public OpBuilder {
     });
   }
 
-  /// Used to notify the rewriter that the IR failed to be rewritten because of
+  /// Used to notify the listener that the IR failed to be rewritten because of
   /// a match failure, and provide a callback to populate a diagnostic with the
   /// reason why the failure occurred. This method allows for derived rewriters
   /// to optionally hook into the reason why a rewrite failed, and display it to
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 2575be4cdea1ac..cde72ba36196c6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -719,12 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
-  /// PatternRewriter hook for replacing an operation when the given functor
-  /// returns "true".
-  void replaceOpWithIf(
-      Operation *op, ValueRange newValues, bool *allUsesReplaced,
-      llvm::unique_function<bool(OpOperand &) const> functor) override;
-
   /// PatternRewriter hook for replacing an operation.
   void replaceOp(Operation *op, ValueRange newValues) override;
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 5cd6d4597affaf..1658ea67a46077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
       scalarReplacements.push_back(
           residualGenericOpBody->getArgument(num + origNumInputs));
     bool allUsesReplaced = false;
-    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
-                                  residualGenericOpBody, &allUsesReplaced);
+    rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements,
+                                    residualGenericOpBody, &allUsesReplaced);
     assert(!allUsesReplaced &&
            "peeled scalar operation is erased when it wasnt expected to be");
   }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5d220c6cdd7e58..ffe972e5842e29 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -871,7 +871,7 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
         {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
     Value materialized =
         getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
-    b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
+    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
       return use.getOwner() != materialized.getDefiningOp();
     });
   }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5ba5328f14b89e..1f8e095bb7a8c3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,41 +110,6 @@ RewriterBase::~RewriterBase() {
   // Out of line to provide a vtable anchor for the class.
 }
 
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when the provided `functor` returns true for a specific use.
-/// The number of values in `newValues` is required to match the number of
-/// results of `op`.
-void RewriterBase::replaceOpWithIf(
-    Operation *op, ValueRange newValues, bool *allUsesReplaced,
-    llvm::unique_function<bool(OpOperand &) const> functor) {
-  assert(op->getNumResults() == newValues.size() &&
-         "incorrect number of values to replace operation");
-
-  // Notify the listener that we're about to replace this op.
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationReplaced(op, newValues);
-
-  // Replace each use of the results when the functor is true.
-  bool replacedAllUses = true;
-  for (auto it : llvm::zip(op->getResults(), newValues)) {
-    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
-    replacedAllUses &= std::get<0>(it).use_empty();
-  }
-  if (allUsesReplaced)
-    *allUsesReplaced = replacedAllUses;
-}
-
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when a use is nested within the given `block`. The number of
-/// values in `newValues` is required to match the number of results of `op`.
-/// If all uses of this operation are replaced, the operation is erased.
-void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
-                                        Block *block, bool *allUsesReplaced) {
-  replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
-    return block->getParentOp()->isProperAncestor(use.getOwner());
-  });
-}
-
 /// This method replaces the results of the operation with the specified list of
 /// values. The number of provided values must match the number of results of
 /// the operation. The replaced op is erased.
@@ -156,9 +121,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newValues);
 
-  // Replace results one-by-one. Also notifies the listener of modifications.
-  for (auto it : llvm::zip(op->getResults(), newValues))
-    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  // Replace all result uses. Also notifies the listener of modifications.
+  replaceAllUsesWith(op, newValues);
 
   // Erase op and notify listener.
   eraseOp(op);
@@ -176,9 +140,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newOp);
 
-  // Replace results one-by-one. Also notifies the listener of modifications.
-  for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
-    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  // Replace all result uses. Also notifies the listener of modifications.
+  replaceAllUsesWith(op, newOp->getResults());
 
   // Erase op and notify listener.
   eraseOp(op);
@@ -276,15 +239,33 @@ void RewriterBase::finalizeOpModification(Operation *op) {
     rewriteListener->notifyOperationModified(op);
 }
 
-/// Find uses of `from` and replace them with `to` if the `functor` returns
-/// true. It also marks every modified uses and notifies the rewriter that an
-/// in-place operation modification is about to happen.
 void RewriterBase::replaceUsesWithIf(Value from, Value to,
-                                     function_ref<bool(OpOperand &)> functor) {
+                                     function_ref<bool(OpOperand &)> functor,
+                                     bool *allUsesReplaced) {
+  bool allReplaced = true;
   for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
-    if (functor(operand))
+    bool replace = functor(operand);
+    if (replace)
       modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
+    allReplaced &= replace;
   }
+  if (allUsesReplaced)
+    *allUsesReplaced = allReplaced;
+}
+
+void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
+                                     function_ref<bool(OpOperand &)> functor,
+                                     bool *allUsesReplaced) {
+  assert(from.size() == to.size() && "incorrect number of replacements");
+  bool allReplaced = true;
+  for (auto it : llvm::zip_equal(from, to)) {
+    bool r;
+    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
+                      /*allUsesReplaced=*/&r);
+    allReplaced &= r;
+  }
+  if (allUsesReplaced)
+    *allUsesReplaced = allReplaced;
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index db41b9f19e7e8d..239d41d2fa2f41 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1173,7 +1173,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     detachNestedAndErase(op);
   }
 
-  // Pop all of the newly created operations.
+  // Pop all of the newly created operations.Patt
   while (createdOps.size() != state.numCreatedOps) {
     detachNestedAndErase(createdOps.back());
     createdOps.pop_back();
@@ -1538,19 +1538,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
-void ConversionPatternRewriter::replaceOpWithIf(
-    Operation *op, ValueRange newValues, bool *allUsesReplaced,
-    llvm::unique_function<bool(OpOperand &) const> functor) {
-  // TODO: To support this we will need to rework a bit of how replacements are
-  // tracked, given that this isn't guranteed to replace all of the uses of an
-  // operation. The main change is that now an operation can be replaced
-  // multiple times, in parts. The current "set" based tracking is mainly useful
-  // for tracking if a replaced operation should be ignored, i.e. if all of the
-  // uses will be replaced.
-  llvm_unreachable(
-      "replaceOpWithIf is currently not supported by DialectConversion");
-}
-
 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
   assert(op && newOp && "expected non-null op");
   replaceOp(op, newOp->getResults());
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e8b07143fc60bd..eff8acdfb33d20 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -161,7 +161,7 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
   rewriter.setInsertionPointToStart(newEntryBlock);
   for (auto *clonedOp : clonedOperations) {
     Operation *newOp = rewriter.clone(*clonedOp, map);
-    rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
+    rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
   }
   rewriter.mergeBlocks(
       entryBlock, newEntryBlock,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 108cfe8950ef67..0036225ca1b27f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1831,7 +1831,7 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
     OperandRange operands = op.getOperands();
 
     // Replace non-terminator uses with the first operand.
-    rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
+    rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
     });
     // Replace everything else with the second operand if the operation isn't

void ConversionPatternRewriter::replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) {
// TODO: To support this we will need to rework a bit of how replacements are
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: replaceAllUsesWith / replaceUsesWithIf is not supported by the dialect conversion either. I have a prototype that adds support for replaceAllUsesWith. At that point, I will override replaceUsesWithIf in the same way as replaceOpWithIf is at the moment (and put an llvm_unrechable).

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like very little usage in tree for these APIs!!

@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_op_api branch from c27d49a to 89bffa1 Compare February 23, 2024 08:33
@matthias-springer
Copy link
Member Author

matthias-springer commented Feb 23, 2024

Seems like very little usage in tree for these APIs!!

We have a few internal uses, but not many. Maybe this API is not really needed, not sure.. I didn't want to delete anything that others may be using.

This change is mainly in preparation of a dialect conversion improvement. I'd like to add support for replaceAllUsesWith. This would also simplify the design a bit (e.g., no more special conversion pattern API needed for block arguments like replaceUsesOfBlockArgument). And RewriterBase::replaceOp would no longer be virtual. Instead RewriterBase::replaceAllUsesWith would be virtual.

I wanted to make sure that the rewriter API is consistent first.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 8 of 8 files at r1, all commit messages.
Reviewable status: all files reviewed, 2 unresolved discussions (waiting on @dcaballe, @MaheshRavishankar, @matthias-springer, and @nicolasvasilache)

* `replaceOp` replaces all uses of the original op and erases the old op.
* `replaceAllUsesWith` replaces all uses of the original op/value/block. It does not erase any IR.

This commit renames `replaceOpWithIf` to `replaceUsesWithIf`. `replaceOpWithIf` was a misnomer because the function never erases the original op. Similarly, `replaceOpWithinBlock` is renamed to `replaceUsesWithinBlock`.

Also improve comments.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_op_api branch from 89bffa1 to 5074d55 Compare March 7, 2024 01:08
@matthias-springer matthias-springer merged commit 59a9201 into main Mar 7, 2024
@matthias-springer matthias-springer deleted the users/matthias-springer/replace_op_api branch March 7, 2024 01:26
matthias-springer added a commit that referenced this pull request Mar 9, 2024
#82629 added additional overloads to `replaceAllUsesWith` and `replaceUsesWithIf`. This caused a build breakage with MSVC when called with ops that can implicitly convert to `Value`.

```
external/llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp(881): error C2666: 'mlir::RewriterBase::replaceAllUsesWith': 2 overloads have similar conversions
external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(631): note: could be 'void mlir::RewriterBase::replaceAllUsesWith(mlir::Operation *,mlir::ValueRange)'
external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(626): note: or       'void mlir::RewriterBase::replaceAllUsesWith(mlir::ValueRange,mlir::ValueRange)'
external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(616): note: or       'void mlir::RewriterBase::replaceAllUsesWith(mlir::Value,mlir::Value)'
external/llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp(882): note: while trying to match the argument list '(mlir::tensor::ExtractSliceOp, T)'
        with
        [
            T=mlir::Value
        ]
```

Note: The LLVM build bots (Linux and Windows) did not break, this seems to be an issue with `Tools\MSVC\14.29.30133\bin\HostX64\x64\cl.exe`.

This change renames the newly added overloads to `replaceAllOpUsesWith` and `replaceOpUsesWithIf`.
matthias-springer added a commit that referenced this pull request Mar 11, 2024
#82629 added additional overloads to `replaceAllUsesWith` and
`replaceUsesWithIf`. This caused a build breakage with MSVC when called
with ops that can implicitly convert to `Value`.

```
external/llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp(881): error C2666: 'mlir::RewriterBase::replaceAllUsesWith': 2 overloads have similar conversions
external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(631): note: could be 'void mlir::RewriterBase::replaceAllUsesWith(mlir::Operation *,mlir::ValueRange)'
external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(626): note: or       'void mlir::RewriterBase::replaceAllUsesWith(mlir::ValueRange,mlir::ValueRange)'
external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(616): note: or       'void mlir::RewriterBase::replaceAllUsesWith(mlir::Value,mlir::Value)'
external/llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp(882): note: while trying to match the argument list '(mlir::tensor::ExtractSliceOp, T)'
        with
        [
            T=mlir::Value
        ]
```

Note: The LLVM build bots (Linux and Windows) did not break, this seems
to be an issue with `Tools\MSVC\14.29.30133\bin\HostX64\x64\cl.exe`.

This change renames the newly added overloads to `replaceAllOpUsesWith`
and `replaceOpUsesWithIf`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants