-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][IR] Make replaceOp
/ replaceAllUsesWith
API consistent
#82629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) Changes
This commit renames Also improve comments. Full diff: https://github.com/llvm/llvm-project/pull/82629.diff 8 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ce3bc3fc2e783..909bd9e99f1481 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -497,42 +497,19 @@ class RewriterBase : public OpBuilder {
Region::iterator before);
void inlineRegionBefore(Region ®ion, Block *before);
- /// This method replaces the uses of the results of `op` with the values in
- /// `newValues` when the provided `functor` returns true for a specific use.
- /// The number of values in `newValues` is required to match the number of
- /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
- /// the uses of `op` were replaced. Note that in some rewriters, the given
- /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
- /// As such, the function should not capture by reference and instead use
- /// value capture as necessary.
- virtual void
- replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor);
- void replaceOpWithIf(Operation *op, ValueRange newValues,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
- std::move(functor));
- }
-
- /// This method replaces the uses of the results of `op` with the values in
- /// `newValues` when a use is nested within the given `block`. The number of
- /// values in `newValues` is required to match the number of results of `op`.
- /// If all uses of this operation are replaced, the operation is erased.
- void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
- bool *allUsesReplaced = nullptr);
-
- /// This method replaces the results of the operation with the specified list
- /// of values. The number of provided values must match the number of results
- /// of the operation. The replaced op is erased.
+ /// Replace the results of the given (original) operation with the specified
+ /// list of values (replacements). The result types of the given op and the
+ /// replacements must match. The original op is erased.
virtual void replaceOp(Operation *op, ValueRange newValues);
- /// This method replaces the results of the operation with the specified
- /// new op (replacement). The number of results of the two operations must
- /// match. The replaced op is erased.
+ /// Replace the results of the given (original) operation with the specified
+ /// new op (replacement). The result types of the two ops must match. The
+ /// original op is erased.
virtual void replaceOp(Operation *op, Operation *newOp);
- /// Replaces the result op with a new op that is created without verification.
- /// The result values of the two ops must be the same types.
+ /// Replace the results of the given (original) op with a new op that is
+ /// created without verification (replacement). The result values of the two
+ /// ops must match. The original op is erased.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
@@ -634,9 +611,8 @@ class RewriterBase : public OpBuilder {
finalizeOpModification(root);
}
- /// Find uses of `from` and replace them with `to`. It also marks every
- /// modified uses and notifies the rewriter that an in-place operation
- /// modification is about to happen.
+ /// Find uses of `from` and replace them with `to`. Also notify the listener
+ /// about every in-place op modification (for every use that was replaced).
void replaceAllUsesWith(Value from, Value to) {
return replaceAllUsesWith(from.getImpl(), to);
}
@@ -652,22 +628,43 @@ class RewriterBase : public OpBuilder {
for (auto it : llvm::zip(from, to))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
}
+ void replaceAllUsesWith(Operation *from, ValueRange to) {
+ replaceAllUsesWith(from->getResults(), to);
+ }
/// Find uses of `from` and replace them with `to` if the `functor` returns
- /// true. It also marks every modified uses and notifies the rewriter that an
- /// in-place operation modification is about to happen.
+ /// true. Also notify the listener about every in-place op modification (for
+ /// every use that was replaced). The optional `allUsesReplaced` flag is set
+ /// to "true" if all uses were replaced.
void replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor);
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
- function_ref<bool(OpOperand &)> functor) {
- assert(from.size() == to.size() && "incorrect number of replacements");
- for (auto it : llvm::zip(from, to))
- replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
+ void replaceUsesWithIf(Operation *from, ValueRange to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr) {
+ replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
+ }
+
+ /// Find uses of `from` within `block` and replace them with `to`. Also notify
+ /// the listener about every in-place op modification (for every use that was
+ /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
+ /// uses were replaced.
+ void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block,
+ bool *allUsesReplaced = nullptr) {
+ replaceUsesWithIf(
+ op, newValues,
+ [block](OpOperand &use) {
+ return block->getParentOp()->isProperAncestor(use.getOwner());
+ },
+ allUsesReplaced);
}
/// Find uses of `from` and replace them with `to` except if the user is
- /// `exceptedUser`. It also marks every modified uses and notifies the
- /// rewriter that an in-place operation modification is about to happen.
+ /// `exceptedUser`. Also notify the listener about every in-place op
+ /// modification (for every use that was replaced).
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
return replaceUsesWithIf(from, to, [&](OpOperand &use) {
Operation *user = use.getOwner();
@@ -675,7 +672,7 @@ class RewriterBase : public OpBuilder {
});
}
- /// Used to notify the rewriter that the IR failed to be rewritten because of
+ /// Used to notify the listener that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
/// reason why the failure occurred. This method allows for derived rewriters
/// to optionally hook into the reason why a rewrite failed, and display it to
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 2575be4cdea1ac..cde72ba36196c6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -719,12 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }
- /// PatternRewriter hook for replacing an operation when the given functor
- /// returns "true".
- void replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) override;
-
/// PatternRewriter hook for replacing an operation.
void replaceOp(Operation *op, ValueRange newValues) override;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 5cd6d4597affaf..1658ea67a46077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
scalarReplacements.push_back(
residualGenericOpBody->getArgument(num + origNumInputs));
bool allUsesReplaced = false;
- rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
- residualGenericOpBody, &allUsesReplaced);
+ rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements,
+ residualGenericOpBody, &allUsesReplaced);
assert(!allUsesReplaced &&
"peeled scalar operation is erased when it wasnt expected to be");
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5d220c6cdd7e58..ffe972e5842e29 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -871,7 +871,7 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
{getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
Value materialized =
getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
- b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
+ b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
return use.getOwner() != materialized.getDefiningOp();
});
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5ba5328f14b89e..1f8e095bb7a8c3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,41 +110,6 @@ RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when the provided `functor` returns true for a specific use.
-/// The number of values in `newValues` is required to match the number of
-/// results of `op`.
-void RewriterBase::replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- assert(op->getNumResults() == newValues.size() &&
- "incorrect number of values to replace operation");
-
- // Notify the listener that we're about to replace this op.
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- rewriteListener->notifyOperationReplaced(op, newValues);
-
- // Replace each use of the results when the functor is true.
- bool replacedAllUses = true;
- for (auto it : llvm::zip(op->getResults(), newValues)) {
- replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
- replacedAllUses &= std::get<0>(it).use_empty();
- }
- if (allUsesReplaced)
- *allUsesReplaced = replacedAllUses;
-}
-
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when a use is nested within the given `block`. The number of
-/// values in `newValues` is required to match the number of results of `op`.
-/// If all uses of this operation are replaced, the operation is erased.
-void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
- Block *block, bool *allUsesReplaced) {
- replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
- return block->getParentOp()->isProperAncestor(use.getOwner());
- });
-}
-
/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased.
@@ -156,9 +121,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);
- // Replace results one-by-one. Also notifies the listener of modifications.
- for (auto it : llvm::zip(op->getResults(), newValues))
- replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+ // Replace all result uses. Also notifies the listener of modifications.
+ replaceAllUsesWith(op, newValues);
// Erase op and notify listener.
eraseOp(op);
@@ -176,9 +140,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newOp);
- // Replace results one-by-one. Also notifies the listener of modifications.
- for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
- replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+ // Replace all result uses. Also notifies the listener of modifications.
+ replaceAllUsesWith(op, newOp->getResults());
// Erase op and notify listener.
eraseOp(op);
@@ -276,15 +239,33 @@ void RewriterBase::finalizeOpModification(Operation *op) {
rewriteListener->notifyOperationModified(op);
}
-/// Find uses of `from` and replace them with `to` if the `functor` returns
-/// true. It also marks every modified uses and notifies the rewriter that an
-/// in-place operation modification is about to happen.
void RewriterBase::replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor) {
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced) {
+ bool allReplaced = true;
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
- if (functor(operand))
+ bool replace = functor(operand);
+ if (replace)
modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
+ allReplaced &= replace;
}
+ if (allUsesReplaced)
+ *allUsesReplaced = allReplaced;
+}
+
+void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced) {
+ assert(from.size() == to.size() && "incorrect number of replacements");
+ bool allReplaced = true;
+ for (auto it : llvm::zip_equal(from, to)) {
+ bool r;
+ replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
+ /*allUsesReplaced=*/&r);
+ allReplaced &= r;
+ }
+ if (allUsesReplaced)
+ *allUsesReplaced = allReplaced;
}
void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index db41b9f19e7e8d..239d41d2fa2f41 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1173,7 +1173,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
detachNestedAndErase(op);
}
- // Pop all of the newly created operations.
+ // Pop all of the newly created operations.Patt
while (createdOps.size() != state.numCreatedOps) {
detachNestedAndErase(createdOps.back());
createdOps.pop_back();
@@ -1538,19 +1538,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
-void ConversionPatternRewriter::replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- // TODO: To support this we will need to rework a bit of how replacements are
- // tracked, given that this isn't guranteed to replace all of the uses of an
- // operation. The main change is that now an operation can be replaced
- // multiple times, in parts. The current "set" based tracking is mainly useful
- // for tracking if a replaced operation should be ignored, i.e. if all of the
- // uses will be replaced.
- llvm_unreachable(
- "replaceOpWithIf is currently not supported by DialectConversion");
-}
-
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e8b07143fc60bd..eff8acdfb33d20 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -161,7 +161,7 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
rewriter.setInsertionPointToStart(newEntryBlock);
for (auto *clonedOp : clonedOperations) {
Operation *newOp = rewriter.clone(*clonedOp, map);
- rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
+ rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
}
rewriter.mergeBlocks(
entryBlock, newEntryBlock,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 108cfe8950ef67..0036225ca1b27f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1831,7 +1831,7 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
OperandRange operands = op.getOperands();
// Replace non-terminator uses with the first operand.
- rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
});
// Replace everything else with the second operand if the operation isn't
|
@llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) Changes
This commit renames Also improve comments. Full diff: https://github.com/llvm/llvm-project/pull/82629.diff 8 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ce3bc3fc2e783..909bd9e99f1481 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -497,42 +497,19 @@ class RewriterBase : public OpBuilder {
Region::iterator before);
void inlineRegionBefore(Region ®ion, Block *before);
- /// This method replaces the uses of the results of `op` with the values in
- /// `newValues` when the provided `functor` returns true for a specific use.
- /// The number of values in `newValues` is required to match the number of
- /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
- /// the uses of `op` were replaced. Note that in some rewriters, the given
- /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
- /// As such, the function should not capture by reference and instead use
- /// value capture as necessary.
- virtual void
- replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor);
- void replaceOpWithIf(Operation *op, ValueRange newValues,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
- std::move(functor));
- }
-
- /// This method replaces the uses of the results of `op` with the values in
- /// `newValues` when a use is nested within the given `block`. The number of
- /// values in `newValues` is required to match the number of results of `op`.
- /// If all uses of this operation are replaced, the operation is erased.
- void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
- bool *allUsesReplaced = nullptr);
-
- /// This method replaces the results of the operation with the specified list
- /// of values. The number of provided values must match the number of results
- /// of the operation. The replaced op is erased.
+ /// Replace the results of the given (original) operation with the specified
+ /// list of values (replacements). The result types of the given op and the
+ /// replacements must match. The original op is erased.
virtual void replaceOp(Operation *op, ValueRange newValues);
- /// This method replaces the results of the operation with the specified
- /// new op (replacement). The number of results of the two operations must
- /// match. The replaced op is erased.
+ /// Replace the results of the given (original) operation with the specified
+ /// new op (replacement). The result types of the two ops must match. The
+ /// original op is erased.
virtual void replaceOp(Operation *op, Operation *newOp);
- /// Replaces the result op with a new op that is created without verification.
- /// The result values of the two ops must be the same types.
+ /// Replace the results of the given (original) op with a new op that is
+ /// created without verification (replacement). The result values of the two
+ /// ops must match. The original op is erased.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
@@ -634,9 +611,8 @@ class RewriterBase : public OpBuilder {
finalizeOpModification(root);
}
- /// Find uses of `from` and replace them with `to`. It also marks every
- /// modified uses and notifies the rewriter that an in-place operation
- /// modification is about to happen.
+ /// Find uses of `from` and replace them with `to`. Also notify the listener
+ /// about every in-place op modification (for every use that was replaced).
void replaceAllUsesWith(Value from, Value to) {
return replaceAllUsesWith(from.getImpl(), to);
}
@@ -652,22 +628,43 @@ class RewriterBase : public OpBuilder {
for (auto it : llvm::zip(from, to))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
}
+ void replaceAllUsesWith(Operation *from, ValueRange to) {
+ replaceAllUsesWith(from->getResults(), to);
+ }
/// Find uses of `from` and replace them with `to` if the `functor` returns
- /// true. It also marks every modified uses and notifies the rewriter that an
- /// in-place operation modification is about to happen.
+ /// true. Also notify the listener about every in-place op modification (for
+ /// every use that was replaced). The optional `allUsesReplaced` flag is set
+ /// to "true" if all uses were replaced.
void replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor);
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
- function_ref<bool(OpOperand &)> functor) {
- assert(from.size() == to.size() && "incorrect number of replacements");
- for (auto it : llvm::zip(from, to))
- replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
+ void replaceUsesWithIf(Operation *from, ValueRange to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr) {
+ replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
+ }
+
+ /// Find uses of `from` within `block` and replace them with `to`. Also notify
+ /// the listener about every in-place op modification (for every use that was
+ /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
+ /// uses were replaced.
+ void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block,
+ bool *allUsesReplaced = nullptr) {
+ replaceUsesWithIf(
+ op, newValues,
+ [block](OpOperand &use) {
+ return block->getParentOp()->isProperAncestor(use.getOwner());
+ },
+ allUsesReplaced);
}
/// Find uses of `from` and replace them with `to` except if the user is
- /// `exceptedUser`. It also marks every modified uses and notifies the
- /// rewriter that an in-place operation modification is about to happen.
+ /// `exceptedUser`. Also notify the listener about every in-place op
+ /// modification (for every use that was replaced).
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
return replaceUsesWithIf(from, to, [&](OpOperand &use) {
Operation *user = use.getOwner();
@@ -675,7 +672,7 @@ class RewriterBase : public OpBuilder {
});
}
- /// Used to notify the rewriter that the IR failed to be rewritten because of
+ /// Used to notify the listener that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
/// reason why the failure occurred. This method allows for derived rewriters
/// to optionally hook into the reason why a rewrite failed, and display it to
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 2575be4cdea1ac..cde72ba36196c6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -719,12 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }
- /// PatternRewriter hook for replacing an operation when the given functor
- /// returns "true".
- void replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) override;
-
/// PatternRewriter hook for replacing an operation.
void replaceOp(Operation *op, ValueRange newValues) override;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 5cd6d4597affaf..1658ea67a46077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
scalarReplacements.push_back(
residualGenericOpBody->getArgument(num + origNumInputs));
bool allUsesReplaced = false;
- rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
- residualGenericOpBody, &allUsesReplaced);
+ rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements,
+ residualGenericOpBody, &allUsesReplaced);
assert(!allUsesReplaced &&
"peeled scalar operation is erased when it wasnt expected to be");
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5d220c6cdd7e58..ffe972e5842e29 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -871,7 +871,7 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
{getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
Value materialized =
getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
- b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
+ b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
return use.getOwner() != materialized.getDefiningOp();
});
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5ba5328f14b89e..1f8e095bb7a8c3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,41 +110,6 @@ RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when the provided `functor` returns true for a specific use.
-/// The number of values in `newValues` is required to match the number of
-/// results of `op`.
-void RewriterBase::replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- assert(op->getNumResults() == newValues.size() &&
- "incorrect number of values to replace operation");
-
- // Notify the listener that we're about to replace this op.
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- rewriteListener->notifyOperationReplaced(op, newValues);
-
- // Replace each use of the results when the functor is true.
- bool replacedAllUses = true;
- for (auto it : llvm::zip(op->getResults(), newValues)) {
- replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
- replacedAllUses &= std::get<0>(it).use_empty();
- }
- if (allUsesReplaced)
- *allUsesReplaced = replacedAllUses;
-}
-
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when a use is nested within the given `block`. The number of
-/// values in `newValues` is required to match the number of results of `op`.
-/// If all uses of this operation are replaced, the operation is erased.
-void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
- Block *block, bool *allUsesReplaced) {
- replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
- return block->getParentOp()->isProperAncestor(use.getOwner());
- });
-}
-
/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased.
@@ -156,9 +121,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);
- // Replace results one-by-one. Also notifies the listener of modifications.
- for (auto it : llvm::zip(op->getResults(), newValues))
- replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+ // Replace all result uses. Also notifies the listener of modifications.
+ replaceAllUsesWith(op, newValues);
// Erase op and notify listener.
eraseOp(op);
@@ -176,9 +140,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newOp);
- // Replace results one-by-one. Also notifies the listener of modifications.
- for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
- replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+ // Replace all result uses. Also notifies the listener of modifications.
+ replaceAllUsesWith(op, newOp->getResults());
// Erase op and notify listener.
eraseOp(op);
@@ -276,15 +239,33 @@ void RewriterBase::finalizeOpModification(Operation *op) {
rewriteListener->notifyOperationModified(op);
}
-/// Find uses of `from` and replace them with `to` if the `functor` returns
-/// true. It also marks every modified uses and notifies the rewriter that an
-/// in-place operation modification is about to happen.
void RewriterBase::replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor) {
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced) {
+ bool allReplaced = true;
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
- if (functor(operand))
+ bool replace = functor(operand);
+ if (replace)
modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
+ allReplaced &= replace;
}
+ if (allUsesReplaced)
+ *allUsesReplaced = allReplaced;
+}
+
+void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced) {
+ assert(from.size() == to.size() && "incorrect number of replacements");
+ bool allReplaced = true;
+ for (auto it : llvm::zip_equal(from, to)) {
+ bool r;
+ replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
+ /*allUsesReplaced=*/&r);
+ allReplaced &= r;
+ }
+ if (allUsesReplaced)
+ *allUsesReplaced = allReplaced;
}
void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index db41b9f19e7e8d..239d41d2fa2f41 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1173,7 +1173,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
detachNestedAndErase(op);
}
- // Pop all of the newly created operations.
+ // Pop all of the newly created operations.Patt
while (createdOps.size() != state.numCreatedOps) {
detachNestedAndErase(createdOps.back());
createdOps.pop_back();
@@ -1538,19 +1538,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
-void ConversionPatternRewriter::replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- // TODO: To support this we will need to rework a bit of how replacements are
- // tracked, given that this isn't guranteed to replace all of the uses of an
- // operation. The main change is that now an operation can be replaced
- // multiple times, in parts. The current "set" based tracking is mainly useful
- // for tracking if a replaced operation should be ignored, i.e. if all of the
- // uses will be replaced.
- llvm_unreachable(
- "replaceOpWithIf is currently not supported by DialectConversion");
-}
-
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e8b07143fc60bd..eff8acdfb33d20 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -161,7 +161,7 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
rewriter.setInsertionPointToStart(newEntryBlock);
for (auto *clonedOp : clonedOperations) {
Operation *newOp = rewriter.clone(*clonedOp, map);
- rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
+ rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
}
rewriter.mergeBlocks(
entryBlock, newEntryBlock,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 108cfe8950ef67..0036225ca1b27f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1831,7 +1831,7 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
OperandRange operands = op.getOperands();
// Replace non-terminator uses with the first operand.
- rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
});
// Replace everything else with the second operand if the operation isn't
|
void ConversionPatternRewriter::replaceOpWithIf( | ||
Operation *op, ValueRange newValues, bool *allUsesReplaced, | ||
llvm::unique_function<bool(OpOperand &) const> functor) { | ||
// TODO: To support this we will need to rework a bit of how replacements are |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: replaceAllUsesWith
/ replaceUsesWithIf
is not supported by the dialect conversion either. I have a prototype that adds support for replaceAllUsesWith
. At that point, I will override replaceUsesWithIf
in the same way as replaceOpWithIf
is at the moment (and put an llvm_unrechable
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like very little usage in tree for these APIs!!
c27d49a
to
89bffa1
Compare
We have a few internal uses, but not many. Maybe this API is not really needed, not sure.. I didn't want to delete anything that others may be using. This change is mainly in preparation of a dialect conversion improvement. I'd like to add support for I wanted to make sure that the rewriter API is consistent first. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 8 of 8 files at r1, all commit messages.
Reviewable status: all files reviewed, 2 unresolved discussions (waiting on @dcaballe, @MaheshRavishankar, @matthias-springer, and @nicolasvasilache)
* `replaceOp` replaces all uses of the original op and erases the old op. * `replaceAllUsesWith` replaces all uses of the original op/value/block. It does not erase any IR. This commit renames `replaceOpWithIf` to `replaceUsesWithIf`. `replaceOpWithIf` was a misnomer because the function never erases the original op. Similarly, `replaceOpWithinBlock` is renamed to `replaceUsesWithinBlock`. Also improve comments. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
89bffa1
to
5074d55
Compare
#82629 added additional overloads to `replaceAllUsesWith` and `replaceUsesWithIf`. This caused a build breakage with MSVC when called with ops that can implicitly convert to `Value`. ``` external/llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp(881): error C2666: 'mlir::RewriterBase::replaceAllUsesWith': 2 overloads have similar conversions external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(631): note: could be 'void mlir::RewriterBase::replaceAllUsesWith(mlir::Operation *,mlir::ValueRange)' external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(626): note: or 'void mlir::RewriterBase::replaceAllUsesWith(mlir::ValueRange,mlir::ValueRange)' external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(616): note: or 'void mlir::RewriterBase::replaceAllUsesWith(mlir::Value,mlir::Value)' external/llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp(882): note: while trying to match the argument list '(mlir::tensor::ExtractSliceOp, T)' with [ T=mlir::Value ] ``` Note: The LLVM build bots (Linux and Windows) did not break, this seems to be an issue with `Tools\MSVC\14.29.30133\bin\HostX64\x64\cl.exe`. This change renames the newly added overloads to `replaceAllOpUsesWith` and `replaceOpUsesWithIf`.
#82629 added additional overloads to `replaceAllUsesWith` and `replaceUsesWithIf`. This caused a build breakage with MSVC when called with ops that can implicitly convert to `Value`. ``` external/llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp(881): error C2666: 'mlir::RewriterBase::replaceAllUsesWith': 2 overloads have similar conversions external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(631): note: could be 'void mlir::RewriterBase::replaceAllUsesWith(mlir::Operation *,mlir::ValueRange)' external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(626): note: or 'void mlir::RewriterBase::replaceAllUsesWith(mlir::ValueRange,mlir::ValueRange)' external/llvm-project/mlir/include\mlir/IR/PatternMatch.h(616): note: or 'void mlir::RewriterBase::replaceAllUsesWith(mlir::Value,mlir::Value)' external/llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp(882): note: while trying to match the argument list '(mlir::tensor::ExtractSliceOp, T)' with [ T=mlir::Value ] ``` Note: The LLVM build bots (Linux and Windows) did not break, this seems to be an issue with `Tools\MSVC\14.29.30133\bin\HostX64\x64\cl.exe`. This change renames the newly added overloads to `replaceAllOpUsesWith` and `replaceOpUsesWithIf`.
replaceOp
replaces all uses of the original op and erases the old op.replaceAllUsesWith
replaces all uses of the original op/value/block. It does not erase any IR.This commit renames
replaceOpWithIf
toreplaceUsesWithIf
.replaceOpWithIf
was a misnomer because the function never erases the original op. Similarly,replaceOpWithinBlock
is renamed toreplaceUsesWithinBlock
. (No "operation replaced" is sent because the op is not erased.)Also improve comments.