Skip to content

Commit cd7af14

Browse files
committed
Fix canonicalizer to copy the entire GreedyRewriteConfig instead of selected fields
It is surprising for the user that only some fields were honored. Also make the FrozenRewritePatternSet a shared_ptr<const T>. Fixes llvm#64543 Differential Revision: https://reviews.llvm.org/D157469
1 parent f5b974b commit cd7af14

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -715,18 +715,20 @@ class AsmParser {
715715
//===--------------------------------------------------------------------===//
716716

717717
/// This class represents a StringSwitch like class that is useful for parsing
718-
/// expected keywords. On construction, it invokes `parseKeyword` and
719-
/// processes each of the provided cases statements until a match is hit. The
720-
/// provided `ResultT` must be assignable from `failure()`.
718+
/// expected keywords. On construction, unless a non-empty keyword is
719+
/// provided, it invokes `parseKeyword` and processes each of the provided
720+
/// cases statements until a match is hit. The provided `ResultT` must be
721+
/// assignable from `failure()`.
721722
template <typename ResultT = ParseResult>
722723
class KeywordSwitch {
723724
public:
724-
KeywordSwitch(AsmParser &parser)
725+
KeywordSwitch(AsmParser &parser, StringRef *keyword = nullptr)
725726
: parser(parser), loc(parser.getCurrentLocation()) {
726-
if (failed(parser.parseKeywordOrCompletion(&keyword)))
727+
if (keyword && !keyword->empty())
728+
this->keyword = *keyword;
729+
else if (failed(parser.parseKeywordOrCompletion(&this->keyword)))
727730
result = failure();
728731
}
729-
730732
/// Case that uses the provided value when true.
731733
KeywordSwitch &Case(StringLiteral str, ResultT value) {
732734
return Case(str, [&](StringRef, SMLoc) { return std::move(value); });

mlir/lib/Transforms/Canonicalizer.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
2929
Canonicalizer() = default;
3030
Canonicalizer(const GreedyRewriteConfig &config,
3131
ArrayRef<std::string> disabledPatterns,
32-
ArrayRef<std::string> enabledPatterns) {
32+
ArrayRef<std::string> enabledPatterns)
33+
: config(config) {
3334
this->topDownProcessingEnabled = config.useTopDownTraversal;
3435
this->enableRegionSimplification = config.enableRegionSimplification;
3536
this->maxIterations = config.maxIterations;
@@ -41,30 +42,31 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
4142
/// Initialize the canonicalizer by building the set of patterns used during
4243
/// execution.
4344
LogicalResult initialize(MLIRContext *context) override {
45+
// Set the config from possible pass options set in the meantime.
46+
config.useTopDownTraversal = topDownProcessingEnabled;
47+
config.enableRegionSimplification = enableRegionSimplification;
48+
config.maxIterations = maxIterations;
49+
config.maxNumRewrites = maxNumRewrites;
50+
4451
RewritePatternSet owningPatterns(context);
4552
for (auto *dialect : context->getLoadedDialects())
4653
dialect->getCanonicalizationPatterns(owningPatterns);
4754
for (RegisteredOperationName op : context->getRegisteredOperations())
4855
op.getCanonicalizationPatterns(owningPatterns, context);
4956

50-
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
51-
disabledPatterns, enabledPatterns);
57+
patterns = std::make_shared<FrozenRewritePatternSet>(
58+
std::move(owningPatterns), disabledPatterns, enabledPatterns);
5259
return success();
5360
}
5461
void runOnOperation() override {
55-
GreedyRewriteConfig config;
56-
config.useTopDownTraversal = topDownProcessingEnabled;
57-
config.enableRegionSimplification = enableRegionSimplification;
58-
config.maxIterations = maxIterations;
59-
config.maxNumRewrites = maxNumRewrites;
6062
LogicalResult converged =
61-
applyPatternsAndFoldGreedily(getOperation(), patterns, config);
63+
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
6264
// Canonicalization is best-effort. Non-convergence is not a pass failure.
6365
if (testConvergence && failed(converged))
6466
signalPassFailure();
6567
}
66-
67-
FrozenRewritePatternSet patterns;
68+
GreedyRewriteConfig config;
69+
std::shared_ptr<const FrozenRewritePatternSet> patterns;
6870
};
6971
} // namespace
7072

0 commit comments

Comments
 (0)