Skip to content

[mlir][IR] Make replaceOp / replaceAllUsesWith API consistent #82629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 42 additions & 45 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,42 +497,19 @@ class RewriterBase : public OpBuilder {
Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);

/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
/// the uses of `op` were replaced. Note that in some rewriters, the given
/// 'functor' may be stored beyond the lifetime of the rewrite being applied.
/// As such, the function should not capture by reference and instead use
/// value capture as necessary.
virtual void
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor);
void replaceOpWithIf(Operation *op, ValueRange newValues,
llvm::unique_function<bool(OpOperand &) const> functor) {
replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
std::move(functor));
}

/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when a use is nested within the given `block`. The number of
/// values in `newValues` is required to match the number of results of `op`.
/// If all uses of this operation are replaced, the operation is erased.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
bool *allUsesReplaced = nullptr);

/// This method replaces the results of the operation with the specified list
/// of values. The number of provided values must match the number of results
/// of the operation. The replaced op is erased.
/// Replace the results of the given (original) operation with the specified
/// list of values (replacements). The result types of the given op and the
/// replacements must match. The original op is erased.
virtual void replaceOp(Operation *op, ValueRange newValues);

/// This method replaces the results of the operation with the specified
/// new op (replacement). The number of results of the two operations must
/// match. The replaced op is erased.
/// Replace the results of the given (original) operation with the specified
/// new op (replacement). The result types of the two ops must match. The
/// original op is erased.
virtual void replaceOp(Operation *op, Operation *newOp);

/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
/// Replace the results of the given (original) op with a new op that is
/// created without verification (replacement). The result values of the two
/// ops must match. The original op is erased.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
Expand Down Expand Up @@ -634,9 +611,8 @@ class RewriterBase : public OpBuilder {
finalizeOpModification(root);
}

/// Find uses of `from` and replace them with `to`. It also marks every
/// modified uses and notifies the rewriter that an in-place operation
/// modification is about to happen.
/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
void replaceAllUsesWith(Value from, Value to) {
return replaceAllUsesWith(from.getImpl(), to);
}
Expand All @@ -652,30 +628,51 @@ class RewriterBase : public OpBuilder {
for (auto it : llvm::zip(from, to))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
}
void replaceAllUsesWith(Operation *from, ValueRange to) {
replaceAllUsesWith(from->getResults(), to);
}

/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. It also marks every modified uses and notifies the rewriter that an
/// in-place operation modification is about to happen.
/// true. Also notify the listener about every in-place op modification (for
/// every use that was replaced). The optional `allUsesReplaced` flag is set
/// to "true" if all uses were replaced.
void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor);
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor) {
assert(from.size() == to.size() && "incorrect number of replacements");
for (auto it : llvm::zip(from, to))
replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(Operation *from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr) {
replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
}

/// Find uses of `from` within `block` and replace them with `to`. Also notify
/// the listener about every in-place op modification (for every use that was
/// replaced). The optional `allUsesReplaced` flag is set to "true" if all
/// uses were replaced.
void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block,
bool *allUsesReplaced = nullptr) {
replaceUsesWithIf(
op, newValues,
[block](OpOperand &use) {
return block->getParentOp()->isProperAncestor(use.getOwner());
},
allUsesReplaced);
}

/// Find uses of `from` and replace them with `to` except if the user is
/// `exceptedUser`. It also marks every modified uses and notifies the
/// rewriter that an in-place operation modification is about to happen.
/// `exceptedUser`. Also notify the listener about every in-place op
/// modification (for every use that was replaced).
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
return replaceUsesWithIf(from, to, [&](OpOperand &use) {
Operation *user = use.getOwner();
return user != exceptedUser;
});
}

/// Used to notify the rewriter that the IR failed to be rewritten because of
/// Used to notify the listener that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
/// reason why the failure occurred. This method allows for derived rewriters
/// to optionally hook into the reason why a rewrite failed, and display it to
Expand Down
6 changes: 0 additions & 6 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -720,12 +720,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }

/// PatternRewriter hook for replacing an operation when the given functor
/// returns "true".
void replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) override;

/// PatternRewriter hook for replacing an operation.
void replaceOp(Operation *op, ValueRange newValues) override;

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
scalarReplacements.push_back(
residualGenericOpBody->getArgument(num + origNumInputs));
bool allUsesReplaced = false;
rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
residualGenericOpBody, &allUsesReplaced);
rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements,
residualGenericOpBody, &allUsesReplaced);
assert(!allUsesReplaced &&
"peeled scalar operation is erased when it wasnt expected to be");
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
{getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
Value materialized =
getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
return use.getOwner() != materialized.getDefiningOp();
});
}
Expand Down
73 changes: 27 additions & 46 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,41 +110,6 @@ RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}

/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
/// results of `op`.
void RewriterBase::replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) {
assert(op->getNumResults() == newValues.size() &&
"incorrect number of values to replace operation");

// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);

// Replace each use of the results when the functor is true.
bool replacedAllUses = true;
for (auto it : llvm::zip(op->getResults(), newValues)) {
replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
replacedAllUses &= std::get<0>(it).use_empty();
}
if (allUsesReplaced)
*allUsesReplaced = replacedAllUses;
}

/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when a use is nested within the given `block`. The number of
/// values in `newValues` is required to match the number of results of `op`.
/// If all uses of this operation are replaced, the operation is erased.
void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
Block *block, bool *allUsesReplaced) {
replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
return block->getParentOp()->isProperAncestor(use.getOwner());
});
}

/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased.
Expand All @@ -156,9 +121,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);

// Replace results one-by-one. Also notifies the listener of modifications.
for (auto it : llvm::zip(op->getResults(), newValues))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
// Replace all result uses. Also notifies the listener of modifications.
replaceAllUsesWith(op, newValues);

// Erase op and notify listener.
eraseOp(op);
Expand All @@ -176,9 +140,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newOp);

// Replace results one-by-one. Also notifies the listener of modifications.
for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
// Replace all result uses. Also notifies the listener of modifications.
replaceAllUsesWith(op, newOp->getResults());

// Erase op and notify listener.
eraseOp(op);
Expand Down Expand Up @@ -279,15 +242,33 @@ void RewriterBase::finalizeOpModification(Operation *op) {
rewriteListener->notifyOperationModified(op);
}

/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. It also marks every modified uses and notifies the rewriter that an
/// in-place operation modification is about to happen.
void RewriterBase::replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor) {
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced) {
bool allReplaced = true;
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
if (functor(operand))
bool replace = functor(operand);
if (replace)
modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
allReplaced &= replace;
}
if (allUsesReplaced)
*allUsesReplaced = allReplaced;
}

void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced) {
assert(from.size() == to.size() && "incorrect number of replacements");
bool allReplaced = true;
for (auto it : llvm::zip_equal(from, to)) {
bool r;
replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
/*allUsesReplaced=*/&r);
allReplaced &= r;
}
if (allUsesReplaced)
*allUsesReplaced = allReplaced;
}

void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
Expand Down
13 changes: 0 additions & 13 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1528,19 +1528,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(

ConversionPatternRewriter::~ConversionPatternRewriter() = default;

void ConversionPatternRewriter::replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) {
// TODO: To support this we will need to rework a bit of how replacements are
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: replaceAllUsesWith / replaceUsesWithIf is not supported by the dialect conversion either. I have a prototype that adds support for replaceAllUsesWith. At that point, I will override replaceUsesWithIf in the same way as replaceOpWithIf is at the moment (and put an llvm_unrechable).

// tracked, given that this isn't guranteed to replace all of the uses of an
// operation. The main change is that now an operation can be replaced
// multiple times, in parts. The current "set" based tracking is mainly useful
// for tracking if a replaced operation should be ignored, i.e. if all of the
// uses will be replaced.
llvm_unreachable(
"replaceOpWithIf is currently not supported by DialectConversion");
}

void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
rewriter.setInsertionPointToStart(newEntryBlock);
for (auto *clonedOp : clonedOperations) {
Operation *newOp = rewriter.clone(*clonedOp, map);
rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
}
rewriter.mergeBlocks(
entryBlock, newEntryBlock,
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1836,7 +1836,7 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
OperandRange operands = op.getOperands();

// Replace non-terminator uses with the first operand.
rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
});
// Replace everything else with the second operand if the operation isn't
Expand Down