Skip to content

Commit 89bffa1

Browse files
[mlir][IR] Make replaceOp / replaceAllUsesWith API consistent
* `replaceOp` replaces all uses of the original op and erases the old op. * `replaceAllUsesWith` replaces all uses of the original op/value/block. It does not erase any IR. This commit renames `replaceOpWithIf` to `replaceUsesWithIf`. `replaceOpWithIf` was a misnomer because the function never erases the original op. Similarly, `replaceOpWithinBlock` is renamed to `replaceUsesWithinBlock`. Also improve comments. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1 parent bcbffd9 commit 89bffa1

File tree

8 files changed

+74
-115
lines changed

8 files changed

+74
-115
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -497,42 +497,19 @@ class RewriterBase : public OpBuilder {
497497
Region::iterator before);
498498
void inlineRegionBefore(Region &region, Block *before);
499499

500-
/// This method replaces the uses of the results of `op` with the values in
501-
/// `newValues` when the provided `functor` returns true for a specific use.
502-
/// The number of values in `newValues` is required to match the number of
503-
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
504-
/// the uses of `op` were replaced. Note that in some rewriters, the given
505-
/// 'functor' may be stored beyond the lifetime of the rewrite being applied.
506-
/// As such, the function should not capture by reference and instead use
507-
/// value capture as necessary.
508-
virtual void
509-
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
510-
llvm::unique_function<bool(OpOperand &) const> functor);
511-
void replaceOpWithIf(Operation *op, ValueRange newValues,
512-
llvm::unique_function<bool(OpOperand &) const> functor) {
513-
replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
514-
std::move(functor));
515-
}
516-
517-
/// This method replaces the uses of the results of `op` with the values in
518-
/// `newValues` when a use is nested within the given `block`. The number of
519-
/// values in `newValues` is required to match the number of results of `op`.
520-
/// If all uses of this operation are replaced, the operation is erased.
521-
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
522-
bool *allUsesReplaced = nullptr);
523-
524-
/// This method replaces the results of the operation with the specified list
525-
/// of values. The number of provided values must match the number of results
526-
/// of the operation. The replaced op is erased.
500+
/// Replace the results of the given (original) operation with the specified
501+
/// list of values (replacements). The result types of the given op and the
502+
/// replacements must match. The original op is erased.
527503
virtual void replaceOp(Operation *op, ValueRange newValues);
528504

529-
/// This method replaces the results of the operation with the specified
530-
/// new op (replacement). The number of results of the two operations must
531-
/// match. The replaced op is erased.
505+
/// Replace the results of the given (original) operation with the specified
506+
/// new op (replacement). The result types of the two ops must match. The
507+
/// original op is erased.
532508
virtual void replaceOp(Operation *op, Operation *newOp);
533509

534-
/// Replaces the result op with a new op that is created without verification.
535-
/// The result values of the two ops must be the same types.
510+
/// Replace the results of the given (original) op with a new op that is
511+
/// created without verification (replacement). The result values of the two
512+
/// ops must match. The original op is erased.
536513
template <typename OpTy, typename... Args>
537514
OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
538515
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
@@ -634,9 +611,8 @@ class RewriterBase : public OpBuilder {
634611
finalizeOpModification(root);
635612
}
636613

637-
/// Find uses of `from` and replace them with `to`. It also marks every
638-
/// modified uses and notifies the rewriter that an in-place operation
639-
/// modification is about to happen.
614+
/// Find uses of `from` and replace them with `to`. Also notify the listener
615+
/// about every in-place op modification (for every use that was replaced).
640616
void replaceAllUsesWith(Value from, Value to) {
641617
return replaceAllUsesWith(from.getImpl(), to);
642618
}
@@ -652,30 +628,51 @@ class RewriterBase : public OpBuilder {
652628
for (auto it : llvm::zip(from, to))
653629
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
654630
}
631+
void replaceAllUsesWith(Operation *from, ValueRange to) {
632+
replaceAllUsesWith(from->getResults(), to);
633+
}
655634

656635
/// Find uses of `from` and replace them with `to` if the `functor` returns
657-
/// true. It also marks every modified uses and notifies the rewriter that an
658-
/// in-place operation modification is about to happen.
636+
/// true. Also notify the listener about every in-place op modification (for
637+
/// every use that was replaced). The optional `allUsesReplaced` flag is set
638+
/// to "true" if all uses were replaced.
659639
void replaceUsesWithIf(Value from, Value to,
660-
function_ref<bool(OpOperand &)> functor);
640+
function_ref<bool(OpOperand &)> functor,
641+
bool *allUsesReplaced = nullptr);
661642
void replaceUsesWithIf(ValueRange from, ValueRange to,
662-
function_ref<bool(OpOperand &)> functor) {
663-
assert(from.size() == to.size() && "incorrect number of replacements");
664-
for (auto it : llvm::zip(from, to))
665-
replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
643+
function_ref<bool(OpOperand &)> functor,
644+
bool *allUsesReplaced = nullptr);
645+
void replaceUsesWithIf(Operation *from, ValueRange to,
646+
function_ref<bool(OpOperand &)> functor,
647+
bool *allUsesReplaced = nullptr) {
648+
replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
649+
}
650+
651+
/// Find uses of `from` within `block` and replace them with `to`. Also notify
652+
/// the listener about every in-place op modification (for every use that was
653+
/// replaced). The optional `allUsesReplaced` flag is set to "true" if all
654+
/// uses were replaced.
655+
void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block,
656+
bool *allUsesReplaced = nullptr) {
657+
replaceUsesWithIf(
658+
op, newValues,
659+
[block](OpOperand &use) {
660+
return block->getParentOp()->isProperAncestor(use.getOwner());
661+
},
662+
allUsesReplaced);
666663
}
667664

668665
/// Find uses of `from` and replace them with `to` except if the user is
669-
/// `exceptedUser`. It also marks every modified uses and notifies the
670-
/// rewriter that an in-place operation modification is about to happen.
666+
/// `exceptedUser`. Also notify the listener about every in-place op
667+
/// modification (for every use that was replaced).
671668
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
672669
return replaceUsesWithIf(from, to, [&](OpOperand &use) {
673670
Operation *user = use.getOwner();
674671
return user != exceptedUser;
675672
});
676673
}
677674

678-
/// Used to notify the rewriter that the IR failed to be rewritten because of
675+
/// Used to notify the listener that the IR failed to be rewritten because of
679676
/// a match failure, and provide a callback to populate a diagnostic with the
680677
/// reason why the failure occurred. This method allows for derived rewriters
681678
/// to optionally hook into the reason why a rewrite failed, and display it to

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -719,12 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
719719
/// patterns even if a failure is encountered during the rewrite step.
720720
bool canRecoverFromRewriteFailure() const override { return true; }
721721

722-
/// PatternRewriter hook for replacing an operation when the given functor
723-
/// returns "true".
724-
void replaceOpWithIf(
725-
Operation *op, ValueRange newValues, bool *allUsesReplaced,
726-
llvm::unique_function<bool(OpOperand &) const> functor) override;
727-
728722
/// PatternRewriter hook for replacing an operation.
729723
void replaceOp(Operation *op, ValueRange newValues) override;
730724

mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
370370
scalarReplacements.push_back(
371371
residualGenericOpBody->getArgument(num + origNumInputs));
372372
bool allUsesReplaced = false;
373-
rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
374-
residualGenericOpBody, &allUsesReplaced);
373+
rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements,
374+
residualGenericOpBody, &allUsesReplaced);
375375
assert(!allUsesReplaced &&
376376
"peeled scalar operation is erased when it wasnt expected to be");
377377
}

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
871871
{getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
872872
Value materialized =
873873
getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
874-
b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
874+
b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
875875
return use.getOwner() != materialized.getDefiningOp();
876876
});
877877
}

mlir/lib/IR/PatternMatch.cpp

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -110,41 +110,6 @@ RewriterBase::~RewriterBase() {
110110
// Out of line to provide a vtable anchor for the class.
111111
}
112112

113-
/// This method replaces the uses of the results of `op` with the values in
114-
/// `newValues` when the provided `functor` returns true for a specific use.
115-
/// The number of values in `newValues` is required to match the number of
116-
/// results of `op`.
117-
void RewriterBase::replaceOpWithIf(
118-
Operation *op, ValueRange newValues, bool *allUsesReplaced,
119-
llvm::unique_function<bool(OpOperand &) const> functor) {
120-
assert(op->getNumResults() == newValues.size() &&
121-
"incorrect number of values to replace operation");
122-
123-
// Notify the listener that we're about to replace this op.
124-
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
125-
rewriteListener->notifyOperationReplaced(op, newValues);
126-
127-
// Replace each use of the results when the functor is true.
128-
bool replacedAllUses = true;
129-
for (auto it : llvm::zip(op->getResults(), newValues)) {
130-
replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
131-
replacedAllUses &= std::get<0>(it).use_empty();
132-
}
133-
if (allUsesReplaced)
134-
*allUsesReplaced = replacedAllUses;
135-
}
136-
137-
/// This method replaces the uses of the results of `op` with the values in
138-
/// `newValues` when a use is nested within the given `block`. The number of
139-
/// values in `newValues` is required to match the number of results of `op`.
140-
/// If all uses of this operation are replaced, the operation is erased.
141-
void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
142-
Block *block, bool *allUsesReplaced) {
143-
replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
144-
return block->getParentOp()->isProperAncestor(use.getOwner());
145-
});
146-
}
147-
148113
/// This method replaces the results of the operation with the specified list of
149114
/// values. The number of provided values must match the number of results of
150115
/// the operation. The replaced op is erased.
@@ -156,9 +121,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
156121
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
157122
rewriteListener->notifyOperationReplaced(op, newValues);
158123

159-
// Replace results one-by-one. Also notifies the listener of modifications.
160-
for (auto it : llvm::zip(op->getResults(), newValues))
161-
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
124+
// Replace all result uses. Also notifies the listener of modifications.
125+
replaceAllUsesWith(op, newValues);
162126

163127
// Erase op and notify listener.
164128
eraseOp(op);
@@ -176,9 +140,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
176140
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
177141
rewriteListener->notifyOperationReplaced(op, newOp);
178142

179-
// Replace results one-by-one. Also notifies the listener of modifications.
180-
for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
181-
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
143+
// Replace all result uses. Also notifies the listener of modifications.
144+
replaceAllUsesWith(op, newOp->getResults());
182145

183146
// Erase op and notify listener.
184147
eraseOp(op);
@@ -276,15 +239,33 @@ void RewriterBase::finalizeOpModification(Operation *op) {
276239
rewriteListener->notifyOperationModified(op);
277240
}
278241

279-
/// Find uses of `from` and replace them with `to` if the `functor` returns
280-
/// true. It also marks every modified uses and notifies the rewriter that an
281-
/// in-place operation modification is about to happen.
282242
void RewriterBase::replaceUsesWithIf(Value from, Value to,
283-
function_ref<bool(OpOperand &)> functor) {
243+
function_ref<bool(OpOperand &)> functor,
244+
bool *allUsesReplaced) {
245+
bool allReplaced = true;
284246
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
285-
if (functor(operand))
247+
bool replace = functor(operand);
248+
if (replace)
286249
modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
250+
allReplaced &= replace;
287251
}
252+
if (allUsesReplaced)
253+
*allUsesReplaced = allReplaced;
254+
}
255+
256+
void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
257+
function_ref<bool(OpOperand &)> functor,
258+
bool *allUsesReplaced) {
259+
assert(from.size() == to.size() && "incorrect number of replacements");
260+
bool allReplaced = true;
261+
for (auto it : llvm::zip_equal(from, to)) {
262+
bool r;
263+
replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
264+
/*allUsesReplaced=*/&r);
265+
allReplaced &= r;
266+
}
267+
if (allUsesReplaced)
268+
*allUsesReplaced = allReplaced;
288269
}
289270

290271
void RewriterBase::inlineBlockBefore(Block *source, Block *dest,

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,19 +1538,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
15381538

15391539
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
15401540

1541-
void ConversionPatternRewriter::replaceOpWithIf(
1542-
Operation *op, ValueRange newValues, bool *allUsesReplaced,
1543-
llvm::unique_function<bool(OpOperand &) const> functor) {
1544-
// TODO: To support this we will need to rework a bit of how replacements are
1545-
// tracked, given that this isn't guranteed to replace all of the uses of an
1546-
// operation. The main change is that now an operation can be replaced
1547-
// multiple times, in parts. The current "set" based tracking is mainly useful
1548-
// for tracking if a replaced operation should be ignored, i.e. if all of the
1549-
// uses will be replaced.
1550-
llvm_unreachable(
1551-
"replaceOpWithIf is currently not supported by DialectConversion");
1552-
}
1553-
15541541
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
15551542
assert(op && newOp && "expected non-null op");
15561543
replaceOp(op, newOp->getResults());

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
161161
rewriter.setInsertionPointToStart(newEntryBlock);
162162
for (auto *clonedOp : clonedOperations) {
163163
Operation *newOp = rewriter.clone(*clonedOp, map);
164-
rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
164+
rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
165165
}
166166
rewriter.mergeBlocks(
167167
entryBlock, newEntryBlock,

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1831,7 +1831,7 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
18311831
OperandRange operands = op.getOperands();
18321832

18331833
// Replace non-terminator uses with the first operand.
1834-
rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1834+
rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
18351835
return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
18361836
});
18371837
// Replace everything else with the second operand if the operation isn't

0 commit comments

Comments
 (0)