Skip to content

[mlir][IR] Trigger notifyOperationReplaced on replaceAllOpUsesWith #84721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,9 @@ class RewriterBase : public OpBuilder {
/// Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationModified(Operation *op) {}

/// Notify the listener that the specified operation is about to be replaced
/// with another operation. This is called before the uses of the old
/// operation have been changed.
/// Notify the listener that all uses of the specified operation's results
/// are about to be replaced with the results of another operation. This is
/// called before the uses of the old operation have been changed.
///
/// By default, this function calls the "operation replaced with values"
/// notification.
Expand All @@ -420,9 +420,10 @@ class RewriterBase : public OpBuilder {
notifyOperationReplaced(op, replacement->getResults());
}

/// Notify the listener that the specified operation is about to be replaced
/// with the a range of values, potentially produced by other operations.
/// This is called before the uses of the operation have been changed.
/// Notify the listener that all uses of the specified operation's results
/// are about to be replaced with the a range of values, potentially
/// produced by other operations. This is called before the uses of the
/// operation have been changed.
virtual void notifyOperationReplaced(Operation *op,
ValueRange replacement) {}

Expand Down Expand Up @@ -648,12 +649,16 @@ class RewriterBase : public OpBuilder {
for (auto it : llvm::zip(from, to))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
}
// Note: This function cannot be called `replaceAllUsesWith` because the
// overload resolution, when called with an op that can be implicitly
// converted to a Value, would be ambiguous.
void replaceAllOpUsesWith(Operation *from, ValueRange to) {
replaceAllUsesWith(from->getResults(), to);
}

/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced)
/// and that the `from` operation is about to be replaced.
///
/// Note: This function cannot be called `replaceAllUsesWith` because the
/// overload resolution, when called with an op that can be implicitly
/// converted to a Value, would be ambiguous.
void replaceAllOpUsesWith(Operation *from, ValueRange to);
void replaceAllOpUsesWith(Operation *from, Operation *to);

/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. Also notify the listener about every in-place op modification (for
Expand Down
24 changes: 16 additions & 8 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,29 @@ RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}

void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(from, to);

replaceAllUsesWith(from->getResults(), to);
}

void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(from, to);

replaceAllUsesWith(from->getResults(), to->getResults());
}

/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased.
void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");

// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);

// Replace all result uses. Also notifies the listener of modifications.
replaceAllOpUsesWith(op, newValues);

Expand All @@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
assert(op->getNumResults() == newOp->getNumResults() &&
"ops have different number of results");

// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newOp);

// Replace all result uses. Also notifies the listener of modifications.
replaceAllOpUsesWith(op, newOp->getResults());

Expand Down
5 changes: 4 additions & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,10 @@ struct TestStrictPatternDriver
OperationName("test.new_op", op->getContext()).getIdentifier(),
op->getOperands(), op->getResultTypes());
}
rewriter.replaceOp(op, newOp->getResults());
// "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
// A "notifyOperationReplaced" callback is triggered in either case.
rewriter.replaceAllOpUsesWith(op, newOp->getResults());
rewriter.eraseOp(op);
return success();
}
};
Expand Down