-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][IR] Add listener notifications for pattern begin/end #84131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Add listener notifications for pattern begin/end #84131
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit adds two new notifications to
The listener infrastructure already provides a This change is in preparation of improving the handle update mechanism in the Full diff: https://github.com/llvm/llvm-project/pull/84131.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index f8d22cfb22afd0..838b4947648f5e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
/// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationErased(Operation *op) {}
- /// Notify the listener that the pattern failed to match the given
- /// operation, and provide a callback to populate a diagnostic with the
- /// reason why the failure occurred. This method allows for derived
- /// listeners to optionally hook into the reason why a rewrite failed, and
- /// display it to users.
+ /// Notify the listener that the specified pattern is about to be applied
+ /// at the specified root operation.
+ virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
+
+ /// Notify the listener that a pattern application finished with the
+ /// specified status. "success" indicates that the pattern was applied
+ /// successfully. "failure" indicates that the pattern could not be
+ /// applied. The pattern may have communicated the reason for the failure
+ /// with `notifyMatchFailure`.
+ virtual void notifyPatternEnd(const Pattern &pattern,
+ LogicalResult status) {}
+
+ /// Notify the listener that the pattern failed to match, and provide a
+ /// callback to populate a diagnostic with the reason why the failure
+ /// occurred. This method allows for derived listeners to optionally hook
+ /// into the reason why a rewrite failed, and display it to users.
virtual void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyOperationErased(op);
}
+ void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ rewriteListener->notifyPatternBegin(pattern, op);
+ }
+ void notifyPatternEnd(const Pattern &pattern,
+ LogicalResult status) override {
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ rewriteListener->notifyPatternEnd(pattern, status);
+ }
void notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a5145246bc30c4..587fbe209b58af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1863,7 +1863,8 @@ class OperationLegalizer {
using LegalizationAction = ConversionTarget::LegalizationAction;
OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns);
+ const FrozenRewritePatternSet &patterns,
+ const ConversionConfig &config);
/// Returns true if the given operation is known to be illegal on the target.
bool isIllegal(Operation *op) const;
@@ -1955,12 +1956,16 @@ class OperationLegalizer {
/// The pattern applicator to use for conversions.
PatternApplicator applicator;
+
+ /// Dialect conversion configuration.
+ const ConversionConfig &config;
};
} // namespace
OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns)
- : target(targetInfo), applicator(patterns) {
+ const FrozenRewritePatternSet &patterns,
+ const ConversionConfig &config)
+ : target(targetInfo), applicator(patterns), config(config) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2105,7 +2110,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- return canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern, rewriter);
+ if (canApply && config.listener)
+ config.listener->notifyPatternBegin(pattern, op);
+ return canApply;
};
// Functor that cleans up the rewriter state after a pattern failed to match.
@@ -2122,6 +2130,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
rewriterImpl.config.notifyCallback(diag);
}
});
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, failure());
rewriterImpl.resetState(curState);
appliedPatterns.erase(&pattern);
};
@@ -2134,6 +2144,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
appliedPatterns.erase(&pattern);
if (failed(result))
rewriterImpl.resetState(curState);
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, result);
return result;
};
@@ -2509,7 +2521,8 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : opLegalizer(target, patterns), config(config), mode(mode) {}
+ : config(config), opLegalizer(target, patterns, this->config),
+ mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
@@ -2546,12 +2559,12 @@ struct OperationConverter {
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping);
- /// The legalizer to use when converting operations.
- OperationLegalizer opLegalizer;
-
/// Dialect conversion configuration.
ConversionConfig config;
+ /// The legalizer to use when converting operations.
+ OperationLegalizer opLegalizer;
+
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
};
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 51d2f5e01b7235..5fda6f87196f94 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Try to match one of the patterns. The rewriter is automatically
// notified of any necessary changes, so there is nothing else to do
// here.
-#ifndef NDEBUG
- auto canApply = [&](const Pattern &pattern) {
- LLVM_DEBUG({
- logger.getOStream() << "\n";
- logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
- << op->getName() << " -> (";
- llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
- logger.getOStream() << ")' {\n";
- logger.indent();
- });
- return true;
- };
- auto onFailure = [&](const Pattern &pattern) {
- LLVM_DEBUG(logResult("failure", "pattern failed to match"));
- };
- auto onSuccess = [&](const Pattern &pattern) {
- LLVM_DEBUG(logResult("success", "pattern applied successfully"));
- return success();
- };
-#else
function_ref<bool(const Pattern &)> canApply = {};
function_ref<void(const Pattern &)> onFailure = {};
function_ref<LogicalResult(const Pattern &)> onSuccess = {};
-#endif
+ bool debugBuild = false;
+#ifdef NDEBUG
+ debugBuild = true;
+#endif // NDEBUG
+ if (debugBuild || config.listener) {
+ canApply = [&](const Pattern &pattern) {
+ LLVM_DEBUG({
+ logger.getOStream() << "\n";
+ logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
+ << op->getName() << " -> (";
+ llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
+ logger.getOStream() << ")' {\n";
+ logger.indent();
+ });
+ if (config.listener)
+ config.listener->notifyPatternBegin(pattern, op);
+ return true;
+ };
+ onFailure = [&](const Pattern &pattern) {
+ LLVM_DEBUG(logResult("failure", "pattern failed to match"));
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, failure());
+ };
+ onSuccess = [&](const Pattern &pattern) {
+ LLVM_DEBUG(logResult("success", "pattern applied successfully"));
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, success());
+ return success();
+ };
+ }
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope) {
@@ -731,7 +740,7 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
- logger.startLine() << "** Failure : " << diag.str() << "\n";
+ logger.startLine() << "** Match Failure : " << diag.str() << "\n";
});
if (config.listener)
config.listener->notifyMatchFailure(loc, reasonCallback);
|
Until now, `transform.apply_conversion_patterns` consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to `transform.apply_patterns`, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion. This new functionality is hidden behind a `preserve_handles` attribute for now.
270ade8
to
c698363
Compare
407c7f7
to
0aef4b9
Compare
function_ref<LogicalResult(const Pattern &)> onSuccess = {}); | ||
std::function<bool(const Pattern &)> canApply = {}, | ||
std::function<void(const Pattern &)> onFailure = {}, | ||
std::function<LogicalResult(const Pattern &)> onSuccess = {}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a stack-use-after-scope
(reported by ASAN, also crashes in opt mode) with function_ref
. The callback in the greedy pattern rewriter is defined inside of an if
check:
function_ref<bool(const Pattern &)> canApply = {};
function_ref<void(const Pattern &)> onFailure = {};
function_ref<LogicalResult(const Pattern &)> onSuccess = {};
bool debugBuild = false;
#ifdef NDEBUG
debugBuild = true;
#endif // NDEBUG
if (debugBuild || config.listener) {
// Lambda captures "this".
canApply = [&](const Pattern &pattern) {
if (this->config.listener) { /* ... */ }
// ...
}
// ...
}
// `canApply` points to a lambda that went out of scope.
LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
function_ref
is a "non-owning function wrapper", but the lambda captures this
.
Changing to std::function
is one way to fix it. I could also just always pass a lambda. That would actually be my preferred solution, but there is a slight overhead when running in opt mode and without listener because the callback would always be called (even if it does not do anything):
LogicalResult matchResult = matcher.matchAndRewrite(
op, *this,
/*canApply=*/[&](const Pattern &pattern) {
if (this->listener) { /* ... */ }
// ...
},
/*onFailure=*/[&](const Pattern &pattern) { /* ... */},
/*onSuccess=*/[&](const Pattern &pattern) { /* ... */});
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(The existing code seemed to care about performance here; There were different canApply
etc. depending on NDEBUG
.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That can explain why you changed it at the call-site, but I'm puzzled about this function: it does not capture the callback as far as I can tell.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are you referring to with this function
?
The problem here is really just caused by the fact that the canApply =
assignment is inside of a nested scope. And the lambda object is dead by the time matcher.matchAndRewrite
is called. I.e., the canApply
function_ref points to an already free'd lambda. At least that's my understanding.
What's the C++ guidelines wrt. to function
vs. function_ref
. This is the first time I ran into such an issue, and assigning lambdas to function_ref
feels "dangerous" to me now (because they can capture stuff). When using function
, I don't have to think about the lifetime of an object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are you referring to with this function?
Where this comment thread is anchored: matchAndRewrite
The problem here is really just caused by the fact that the canApply = assignment is inside of a nested scope. And the lambda object is dead by the time matcher.matchAndRewrite is called. I.e., the canApply function_ref points to an already free'd lambda. At least that's my understanding.
Yes, but that's a problem for the call-site (processWorklist()
), I don't quite see where you make the connection to the signature of matchAndRewrite
?
What's the C++ guidelines wrt. to function vs. function_ref. This is the first time I ran into such an issue, and assigning lambdas to function_ref feels "dangerous" to me now (because they can capture stuff). When using function, I don't have to think about the lifetime of an object.
It is the same a StringRef
: in general you use them in the function API / signature. More rarely for local variable (usually lambda gets assigned an auto
variable)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good way to think about it.
0aef4b9
to
24a56ca
Compare
onFailure = nullptr; | ||
onSuccess = nullptr; | ||
} | ||
#endif // NDEBUG |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I didn't suggest changing this, what you had here was reasonable!
24a56ca
to
a65d640
Compare
This commit adds two new notifications to
RewriterBase::Listener
:notifyPatternBegin
: Called when a pattern application begins during a greedy pattern rewrite or dialect conversion.notifyPatternEnd
: Called when a pattern application finishes during a greedy pattern rewrite or dialect conversion.The listener infrastructure already provides a
notifyMatchFailure
callback that notifies about the reason for a pattern match failure. The two new notifications provide additional information about pattern applications.This change is in preparation of improving the handle update mechanism in the
apply_conversion_patterns
transform op.