Skip to content

Commit 42ac4f3

Browse files
committed
[mlir] Canonicalizer constructor should accept disabled/enabled patterns
There is no way to programmatically configure the list of disabled and enabled patterns in the canonicalizer pass, other than the duplicate the whole pass. This patch exposes the `disabledPatterns` and `enabledPatterns` options. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D116055
1 parent 25226f3 commit 42ac4f3

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,13 @@ class FrozenRewritePatternSet {
4040

4141
/// Freeze the patterns held in `patterns`, and take ownership.
4242
/// `disabledPatternLabels` is a set of labels used to filter out input
43-
/// patterns with a label in this set. `enabledPatternLabels` is a set of
44-
/// labels used to filter out input patterns that do not have one of the
45-
/// labels in this set.
43+
/// patterns with a debug label or debug name in this set.
44+
/// `enabledPatternLabels` is a set of labels used to filter out input
45+
/// patterns that do not have one of the labels in this set. Debug labels must
46+
/// be set explicitly on patterns or when adding them with
47+
/// `RewritePatternSet::addWithLabel`. Debug names may be empty, but patterns
48+
/// created with `RewritePattern::create` have their default debug name set to
49+
/// their type name.
4650
FrozenRewritePatternSet(
4751
RewritePatternSet &&patterns,
4852
ArrayRef<std::string> disabledPatternLabels = llvm::None,

mlir/include/mlir/Transforms/Passes.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,17 @@ std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
6262
std::unique_ptr<Pass> createCanonicalizerPass();
6363

6464
/// Creates an instance of the Canonicalizer pass with the specified config.
65+
/// `disabledPatterns` is a set of labels used to filter out input patterns with
66+
/// a debug label or debug name in this set. `enabledPatterns` is a set of
67+
/// labels used to filter out input patterns that do not have one of the labels
68+
/// in this set. Debug labels must be set explicitly on patterns or when adding
69+
/// them with `RewritePatternSet::addWithLabel`. Debug names may be empty, but
70+
/// patterns created with `RewritePattern::create` have their default debug name
71+
/// set to their type name.
6572
std::unique_ptr<Pass>
66-
createCanonicalizerPass(const GreedyRewriteConfig &config);
73+
createCanonicalizerPass(const GreedyRewriteConfig &config,
74+
ArrayRef<std::string> disabledPatterns = llvm::None,
75+
ArrayRef<std::string> enabledPatterns = llvm::None);
6776

6877
/// Creates a pass to perform common sub expression elimination.
6978
std::unique_ptr<Pass> createCSEPass();

mlir/lib/Transforms/Canonicalizer.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ using namespace mlir;
2121
namespace {
2222
/// Canonicalize operations in nested regions.
2323
struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
24-
Canonicalizer(const GreedyRewriteConfig &config) : config(config) {}
24+
Canonicalizer(const GreedyRewriteConfig &config,
25+
ArrayRef<std::string> disabledPatterns,
26+
ArrayRef<std::string> enabledPatterns)
27+
: config(config) {
28+
this->disabledPatterns = disabledPatterns;
29+
this->enabledPatterns = enabledPatterns;
30+
}
2531

2632
Canonicalizer() {
2733
// Default constructed Canonicalizer takes its settings from command line
@@ -61,6 +67,9 @@ std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
6167

6268
/// Creates an instance of the Canonicalizer pass with the specified config.
6369
std::unique_ptr<Pass>
64-
mlir::createCanonicalizerPass(const GreedyRewriteConfig &config) {
65-
return std::make_unique<Canonicalizer>(config);
70+
createCanonicalizerPass(const GreedyRewriteConfig &config,
71+
ArrayRef<std::string> disabledPatterns = llvm::None,
72+
ArrayRef<std::string> enabledPatterns = llvm::None) {
73+
return std::make_unique<Canonicalizer>(config, disabledPatterns,
74+
enabledPatterns);
6675
}

0 commit comments

Comments
 (0)