Skip to content

Commit 38113a0

Browse files
[mlir][IR] Trigger notifyOperationReplaced on replaceAllOpUsesWith (#84721)
Before this change: `notifyOperationReplaced` was triggered when calling `RewriteBase::replaceOp`. After this change: `notifyOperationReplaced` is triggered when `RewriterBase::replaceAllOpUsesWith` or `RewriterBase::replaceOp` is called. Until now, every `notifyOperationReplaced` was always sent together with a `notifyOperationErased`, which made that `notifyOperationErased` callback irrelevant. More importantly, when a user called `RewriterBase::replaceAllOpUsesWith`+`RewriterBase::eraseOp` instead of `RewriterBase::replaceOp`, no `notifyOperationReplaced` callback was sent, even though the two notations are semantically equivalent. As an example, this can be a problem when applying patterns with the transform dialect because the `TrackingListener` will only see the `notifyOperationErased` callback and the payload op is dropped from the mappings. Note: It is still possible to write semantically equivalent code that does not trigger a `notifyOperationReplaced` (e.g., when op results are replaced one-by-one), but this commit already improves the situation a lot.
1 parent d7a43a0 commit 38113a0

File tree

3 files changed

+37
-21
lines changed

3 files changed

+37
-21
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,9 @@ class RewriterBase : public OpBuilder {
409409
/// Notify the listener that the specified operation was modified in-place.
410410
virtual void notifyOperationModified(Operation *op) {}
411411

412-
/// Notify the listener that the specified operation is about to be replaced
413-
/// with another operation. This is called before the uses of the old
414-
/// operation have been changed.
412+
/// Notify the listener that all uses of the specified operation's results
413+
/// are about to be replaced with the results of another operation. This is
414+
/// called before the uses of the old operation have been changed.
415415
///
416416
/// By default, this function calls the "operation replaced with values"
417417
/// notification.
@@ -420,9 +420,10 @@ class RewriterBase : public OpBuilder {
420420
notifyOperationReplaced(op, replacement->getResults());
421421
}
422422

423-
/// Notify the listener that the specified operation is about to be replaced
424-
/// with the a range of values, potentially produced by other operations.
425-
/// This is called before the uses of the operation have been changed.
423+
/// Notify the listener that all uses of the specified operation's results
424+
/// are about to be replaced with the a range of values, potentially
425+
/// produced by other operations. This is called before the uses of the
426+
/// operation have been changed.
426427
virtual void notifyOperationReplaced(Operation *op,
427428
ValueRange replacement) {}
428429

@@ -648,12 +649,16 @@ class RewriterBase : public OpBuilder {
648649
for (auto it : llvm::zip(from, to))
649650
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
650651
}
651-
// Note: This function cannot be called `replaceAllUsesWith` because the
652-
// overload resolution, when called with an op that can be implicitly
653-
// converted to a Value, would be ambiguous.
654-
void replaceAllOpUsesWith(Operation *from, ValueRange to) {
655-
replaceAllUsesWith(from->getResults(), to);
656-
}
652+
653+
/// Find uses of `from` and replace them with `to`. Also notify the listener
654+
/// about every in-place op modification (for every use that was replaced)
655+
/// and that the `from` operation is about to be replaced.
656+
///
657+
/// Note: This function cannot be called `replaceAllUsesWith` because the
658+
/// overload resolution, when called with an op that can be implicitly
659+
/// converted to a Value, would be ambiguous.
660+
void replaceAllOpUsesWith(Operation *from, ValueRange to);
661+
void replaceAllOpUsesWith(Operation *from, Operation *to);
657662

658663
/// Find uses of `from` and replace them with `to` if the `functor` returns
659664
/// true. Also notify the listener about every in-place op modification (for

mlir/lib/IR/PatternMatch.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,17 +110,29 @@ RewriterBase::~RewriterBase() {
110110
// Out of line to provide a vtable anchor for the class.
111111
}
112112

113+
void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
114+
// Notify the listener that we're about to replace this op.
115+
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
116+
rewriteListener->notifyOperationReplaced(from, to);
117+
118+
replaceAllUsesWith(from->getResults(), to);
119+
}
120+
121+
void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
122+
// Notify the listener that we're about to replace this op.
123+
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
124+
rewriteListener->notifyOperationReplaced(from, to);
125+
126+
replaceAllUsesWith(from->getResults(), to->getResults());
127+
}
128+
113129
/// This method replaces the results of the operation with the specified list of
114130
/// values. The number of provided values must match the number of results of
115131
/// the operation. The replaced op is erased.
116132
void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
117133
assert(op->getNumResults() == newValues.size() &&
118134
"incorrect # of replacement values");
119135

120-
// Notify the listener that we're about to replace this op.
121-
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
122-
rewriteListener->notifyOperationReplaced(op, newValues);
123-
124136
// Replace all result uses. Also notifies the listener of modifications.
125137
replaceAllOpUsesWith(op, newValues);
126138

@@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
136148
assert(op->getNumResults() == newOp->getNumResults() &&
137149
"ops have different number of results");
138150

139-
// Notify the listener that we're about to replace this op.
140-
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
141-
rewriteListener->notifyOperationReplaced(op, newOp);
142-
143151
// Replace all result uses. Also notifies the listener of modifications.
144152
replaceAllOpUsesWith(op, newOp->getResults());
145153

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,10 @@ struct TestStrictPatternDriver
489489
OperationName("test.new_op", op->getContext()).getIdentifier(),
490490
op->getOperands(), op->getResultTypes());
491491
}
492-
rewriter.replaceOp(op, newOp->getResults());
492+
// "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
493+
// A "notifyOperationReplaced" callback is triggered in either case.
494+
rewriter.replaceAllOpUsesWith(op, newOp->getResults());
495+
rewriter.eraseOp(op);
493496
return success();
494497
}
495498
};

0 commit comments

Comments
 (0)