Skip to content

Commit 9a028af

Browse files
[mlir][IR][NFC] Listener::notifyMatchFailure returns void (llvm#80704)
There are two `notifyMatchFailure` methods: one in the rewriter and one in the listener. The one in the rewriter notifies the listener and returns "failure" for convenience. The one in the listener should not return anything; it is just a notification. It can currently be abused to return "success" from the rewriter function. That would be a violation of the rewriter API rules. Also make sure that the listener is always notified about match failures, not just with `NDEBUG`. The current implementation is consistent: one `notifyMatchFailure` overload notifies only in debug mode and another one notifies all the time.
1 parent 93962ea commit 9a028af

File tree

6 files changed

+19
-29
lines changed

6 files changed

+19
-29
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ class TrackingListener : public RewriterBase::Listener,
992992
/// Notify the listener that the pattern failed to match the given operation,
993993
/// and provide a callback to populate a diagnostic with the reason why the
994994
/// failure occurred.
995-
LogicalResult
995+
void
996996
notifyMatchFailure(Location loc,
997997
function_ref<void(Diagnostic &)> reasonCallback) override;
998998

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,9 @@ class RewriterBase : public OpBuilder {
437437
/// reason why the failure occurred. This method allows for derived
438438
/// listeners to optionally hook into the reason why a rewrite failed, and
439439
/// display it to users.
440-
virtual LogicalResult
440+
virtual void
441441
notifyMatchFailure(Location loc,
442-
function_ref<void(Diagnostic &)> reasonCallback) {
443-
return failure();
444-
}
442+
function_ref<void(Diagnostic &)> reasonCallback) {}
445443

446444
static bool classof(const OpBuilder::Listener *base);
447445
};
@@ -480,12 +478,11 @@ class RewriterBase : public OpBuilder {
480478
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
481479
rewriteListener->notifyOperationRemoved(op);
482480
}
483-
LogicalResult notifyMatchFailure(
481+
void notifyMatchFailure(
484482
Location loc,
485483
function_ref<void(Diagnostic &)> reasonCallback) override {
486484
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
487-
return rewriteListener->notifyMatchFailure(loc, reasonCallback);
488-
return failure();
485+
rewriteListener->notifyMatchFailure(loc, reasonCallback);
489486
}
490487

491488
private:
@@ -688,20 +685,16 @@ class RewriterBase : public OpBuilder {
688685
template <typename CallbackT>
689686
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
690687
notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
691-
#ifndef NDEBUG
692688
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
693-
return rewriteListener->notifyMatchFailure(
689+
rewriteListener->notifyMatchFailure(
694690
loc, function_ref<void(Diagnostic &)>(reasonCallback));
695691
return failure();
696-
#else
697-
return failure();
698-
#endif
699692
}
700693
template <typename CallbackT>
701694
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
702695
notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
703696
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
704-
return rewriteListener->notifyMatchFailure(
697+
rewriteListener->notifyMatchFailure(
705698
op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
706699
return failure();
707700
}

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
740740
void cancelOpModification(Operation *op) override;
741741

742742
/// PatternRewriter hook for notifying match failure reasons.
743-
LogicalResult
743+
void
744744
notifyMatchFailure(Location loc,
745745
function_ref<void(Diagnostic &)> reasonCallback) override;
746746
using PatternRewriter::notifyMatchFailure;

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,14 +1265,13 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
12651265
return diag;
12661266
}
12671267

1268-
LogicalResult transform::TrackingListener::notifyMatchFailure(
1268+
void transform::TrackingListener::notifyMatchFailure(
12691269
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
12701270
LLVM_DEBUG({
12711271
Diagnostic diag(loc, DiagnosticSeverity::Remark);
12721272
reasonCallback(diag);
12731273
DBGS() << "Match Failure : " << diag.str() << "\n";
12741274
});
1275-
return failure();
12761275
}
12771276

12781277
void transform::TrackingListener::notifyOperationRemoved(Operation *op) {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -921,9 +921,8 @@ struct ConversionPatternRewriterImpl {
921921
Block::iterator before);
922922

923923
/// Notifies that a pattern match failed for the given reason.
924-
LogicalResult
925-
notifyMatchFailure(Location loc,
926-
function_ref<void(Diagnostic &)> reasonCallback);
924+
void notifyMatchFailure(Location loc,
925+
function_ref<void(Diagnostic &)> reasonCallback);
927926

928927
//===--------------------------------------------------------------------===//
929928
// State
@@ -1236,10 +1235,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12361235
legalTypes.clear();
12371236
if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
12381237
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1239-
return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1238+
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
12401239
diag << "unable to convert type for " << valueDiagTag << " #"
12411240
<< it.index() << ", type was " << origType;
12421241
});
1242+
return failure();
12431243
}
12441244
// TODO: There currently isn't any mechanism to do 1->N type conversion
12451245
// via the PatternRewriter replacement API, so for now we just ignore it.
@@ -1419,7 +1419,7 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
14191419
blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
14201420
}
14211421

1422-
LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
1422+
void ConversionPatternRewriterImpl::notifyMatchFailure(
14231423
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
14241424
LLVM_DEBUG({
14251425
Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -1428,7 +1428,6 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
14281428
if (notifyCallback)
14291429
notifyCallback(diag);
14301430
});
1431-
return failure();
14321431
}
14331432

14341433
//===----------------------------------------------------------------------===//
@@ -1615,9 +1614,9 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
16151614
rootUpdates.erase(rootUpdates.begin() + updateIdx);
16161615
}
16171616

1618-
LogicalResult ConversionPatternRewriter::notifyMatchFailure(
1617+
void ConversionPatternRewriter::notifyMatchFailure(
16191618
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1620-
return impl->notifyMatchFailure(loc, reasonCallback);
1619+
impl->notifyMatchFailure(loc, reasonCallback);
16211620
}
16221621

16231622
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
387387
void notifyBlockRemoved(Block *block) override;
388388

389389
/// For debugging only: Notify the driver of a pattern match failure.
390-
LogicalResult
390+
void
391391
notifyMatchFailure(Location loc,
392392
function_ref<void(Diagnostic &)> reasonCallback) override;
393393

@@ -726,16 +726,15 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
726726
config.listener->notifyOperationReplaced(op, replacement);
727727
}
728728

729-
LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
729+
void GreedyPatternRewriteDriver::notifyMatchFailure(
730730
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
731731
LLVM_DEBUG({
732732
Diagnostic diag(loc, DiagnosticSeverity::Remark);
733733
reasonCallback(diag);
734734
logger.startLine() << "** Failure : " << diag.str() << "\n";
735735
});
736736
if (config.listener)
737-
return config.listener->notifyMatchFailure(loc, reasonCallback);
738-
return failure();
737+
config.listener->notifyMatchFailure(loc, reasonCallback);
739738
}
740739

741740
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)