Skip to content

Commit 1e98d48

Browse files
authored
[mlir][linalg] NFC: Use tablegen macro for pass constructors (llvm#82892)
This uses the tablegen macros for generating pass constructors, exposing pass options for fold-unit-extent-dims and linalg-detensorize. Additionally aligns some of the pass namings to their text counterpart. This includes an API change: createLinalgGeneralizationPass -> createLinalgGeneralizeNamedOpsPass
1 parent 8e22fff commit 1e98d48

File tree

12 files changed

+89
-158
lines changed

12 files changed

+89
-158
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.h

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,43 +27,7 @@ struct OneShotBufferizationOptions;
2727
} // namespace bufferization
2828

2929
#define GEN_PASS_DECL
30-
#include "mlir/Dialect/Linalg/Passes.h.inc"
31-
32-
std::unique_ptr<Pass> createConvertElementwiseToLinalgPass();
33-
34-
std::unique_ptr<Pass> createLinalgFoldUnitExtentDimsPass();
35-
36-
std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
37-
std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
38-
39-
std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
40-
41-
std::unique_ptr<Pass> createLinalgInlineScalarOperandsPass();
42-
43-
/// Create a pass to convert Linalg operations to scf.for loops and
44-
/// memref.load/memref.store accesses.
45-
std::unique_ptr<Pass> createConvertLinalgToLoopsPass();
46-
47-
/// Create a pass to convert Linalg operations to scf.parallel loops and
48-
/// memref.load/memref.store accesses.
49-
std::unique_ptr<Pass> createConvertLinalgToParallelLoopsPass();
50-
51-
/// Create a pass to convert Linalg operations to affine.for loops and
52-
/// affine_load/affine_store accesses.
53-
/// Placeholder for now, this is NYI.
54-
std::unique_ptr<Pass> createConvertLinalgToAffineLoopsPass();
55-
56-
/// Create a pass to convert Linalg operations which work on tensors to use
57-
/// buffers instead.
58-
std::unique_ptr<Pass> createLinalgBufferizePass();
59-
60-
/// Create a pass to convert named Linalg operations to Linalg generic
61-
/// operations.
62-
std::unique_ptr<Pass> createLinalgGeneralizationPass();
63-
64-
/// Create a pass to convert Linalg operations to equivalent operations that
65-
/// work on primitive types, if possible.
66-
std::unique_ptr<Pass> createLinalgDetensorizePass();
30+
#include "mlir/Dialect/Linalg/Passes.h.inc" // IWYU pragma: keep
6731

6832
//===----------------------------------------------------------------------===//
6933
// Registration

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 38 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> {
14+
def ConvertElementwiseToLinalgPass : Pass<"convert-elementwise-to-linalg", ""> {
1515
let summary = "Convert ElementwiseMappable ops to linalg";
1616
let description = [{
1717
Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
@@ -20,54 +20,17 @@ def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> {
2020
run on op which contains linalg ops (most commonly a
2121
FunctionOpInterface op).
2222
}];
23-
let constructor = "mlir::createConvertElementwiseToLinalgPass()";
2423
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
2524
}
2625

27-
def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
28-
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
29-
let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
30-
let options = [
31-
Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
32-
/*default=*/"false",
33-
"Generate rank-reducing slices instead of reassociative reshapes">
34-
];
35-
let dependentDialects = [
36-
"linalg::LinalgDialect", "affine::AffineDialect", "memref::MemRefDialect"
37-
];
38-
}
39-
40-
def LinalgElementwiseOpFusion : Pass<"linalg-fuse-elementwise-ops"> {
41-
let summary = "Fuse elementwise operations on tensors";
42-
let constructor = "mlir::createLinalgElementwiseOpFusionPass()";
43-
let dependentDialects = [
44-
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
45-
];
46-
}
47-
48-
def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
49-
let summary = "Convert from one named linalg op to another.";
50-
let constructor = "mlir::createLinalgNamedOpConversionPass()";
51-
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
52-
}
53-
54-
def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands"> {
55-
let summary = "Inline scalar operands into linalg generic ops";
56-
let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
57-
let dependentDialects = [
58-
"linalg::LinalgDialect"
59-
];
60-
}
61-
62-
def LinalgLowerToAffineLoops : Pass<"convert-linalg-to-affine-loops"> {
26+
def ConvertLinalgToAffineLoopsPass : Pass<"convert-linalg-to-affine-loops"> {
6327
let summary = "Lower the operations from the linalg dialect into affine "
6428
"loops";
65-
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
6629
let dependentDialects = [
6730
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
6831
}
6932

70-
def LinalgLowerToLoops : Pass<"convert-linalg-to-loops"> {
33+
def ConvertLinalgToLoopsPass : Pass<"convert-linalg-to-loops"> {
7134
let summary = "Lower the operations from the linalg dialect into loops";
7235
let description = [{
7336
Lowers the `linalg` ops to loop nests using `scf.for`.
@@ -76,19 +39,17 @@ def LinalgLowerToLoops : Pass<"convert-linalg-to-loops"> {
7639
i.e., tensor operands and results must be converted to memrefs via
7740
bufferization.
7841
}];
79-
let constructor = "mlir::createConvertLinalgToLoopsPass()";
8042
let dependentDialects = [
8143
"linalg::LinalgDialect",
8244
"scf::SCFDialect",
8345
"affine::AffineDialect"
8446
];
8547
}
8648

87-
def LinalgLowerToParallelLoops
49+
def ConvertLinalgToParallelLoopsPass
8850
: Pass<"convert-linalg-to-parallel-loops"> {
8951
let summary = "Lower the operations from the linalg dialect into parallel "
9052
"loops";
91-
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
9253
let dependentDialects = [
9354
"affine::AffineDialect",
9455
"linalg::LinalgDialect",
@@ -97,9 +58,39 @@ def LinalgLowerToParallelLoops
9758
];
9859
}
9960

100-
def LinalgBufferize : Pass<"linalg-bufferize"> {
61+
def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
62+
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
63+
let options = [
64+
Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
65+
/*default=*/"false",
66+
"Generate rank-reducing slices instead of reassociative reshapes">
67+
];
68+
let dependentDialects = [
69+
"linalg::LinalgDialect", "affine::AffineDialect", "memref::MemRefDialect"
70+
];
71+
}
72+
73+
def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
74+
let summary = "Fuse elementwise operations on tensors";
75+
let dependentDialects = [
76+
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
77+
];
78+
}
79+
80+
def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
81+
let summary = "Convert from one named linalg op to another.";
82+
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
83+
}
84+
85+
def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
86+
let summary = "Inline scalar operands into linalg generic ops";
87+
let dependentDialects = [
88+
"linalg::LinalgDialect"
89+
];
90+
}
91+
92+
def LinalgBufferizePass : Pass<"linalg-bufferize"> {
10193
let summary = "Bufferize the linalg dialect";
102-
let constructor = "mlir::createLinalgBufferizePass()";
10394
let dependentDialects = [
10495
"affine::AffineDialect",
10596
"bufferization::BufferizationDialect",
@@ -108,15 +99,13 @@ def LinalgBufferize : Pass<"linalg-bufferize"> {
10899
];
109100
}
110101

111-
def LinalgGeneralization : Pass<"linalg-generalize-named-ops"> {
102+
def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
112103
let summary = "Convert named ops into generic ops";
113-
let constructor = "mlir::createLinalgGeneralizationPass()";
114104
let dependentDialects = ["linalg::LinalgDialect"];
115105
}
116106

117-
def LinalgDetensorize : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
107+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
118108
let summary = "Detensorize linalg ops";
119-
let constructor = "mlir::createLinalgDetensorizePass()";
120109
let dependentDialects = [];
121110

122111
let description = [{

mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "mlir/Pass/Pass.h"
2222

2323
namespace mlir {
24-
#define GEN_PASS_DEF_LINALGBUFFERIZE
24+
#define GEN_PASS_DEF_LINALGBUFFERIZEPASS
2525
#include "mlir/Dialect/Linalg/Passes.h.inc"
2626
} // namespace mlir
2727

@@ -32,7 +32,9 @@ namespace {
3232
/// Converts Linalg operations that work on tensor-type operands or results to
3333
/// work on buffers.
3434
struct LinalgBufferizePass
35-
: public impl::LinalgBufferizeBase<LinalgBufferizePass> {
35+
: public impl::LinalgBufferizePassBase<LinalgBufferizePass> {
36+
using impl::LinalgBufferizePassBase<
37+
LinalgBufferizePass>::LinalgBufferizePassBase;
3638
void runOnOperation() override {
3739
BufferizationOptions options = getPartialBufferizationOptions();
3840
options.opFilter.allowDialect<linalg::LinalgDialect>();
@@ -48,7 +50,3 @@ struct LinalgBufferizePass
4850
}
4951
};
5052
} // namespace
51-
52-
std::unique_ptr<Pass> mlir::createLinalgBufferizePass() {
53-
return std::make_unique<LinalgBufferizePass>();
54-
}

mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include <utility>
2222

2323
namespace mlir {
24-
#define GEN_PASS_DEF_LINALGDETENSORIZE
24+
#define GEN_PASS_DEF_LINALGDETENSORIZEPASS
2525
#include "mlir/Dialect/Linalg/Passes.h.inc"
2626
} // namespace mlir
2727

@@ -164,7 +164,9 @@ class DetensorizeTypeConverter : public TypeConverter {
164164

165165
/// @see LinalgDetensorize in Linalg/Passes.td for more details.
166166
struct LinalgDetensorize
167-
: public impl::LinalgDetensorizeBase<LinalgDetensorize> {
167+
: public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
168+
using impl::LinalgDetensorizePassBase<
169+
LinalgDetensorize>::LinalgDetensorizePassBase;
168170
LinalgDetensorize() = default;
169171

170172
class CostModel {
@@ -576,7 +578,3 @@ struct LinalgDetensorize
576578
}
577579
};
578580
} // namespace
579-
580-
std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
581-
return std::make_unique<LinalgDetensorize>();
582-
}

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "llvm/Support/Debug.h"
3434

3535
namespace mlir {
36-
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
36+
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
3737
#include "mlir/Dialect/Linalg/Passes.h.inc"
3838
} // namespace mlir
3939

@@ -689,7 +689,10 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
689689
namespace {
690690
/// Pass that removes unit-extent dims within generic ops.
691691
struct LinalgFoldUnitExtentDimsPass
692-
: public impl::LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
692+
: public impl::LinalgFoldUnitExtentDimsPassBase<
693+
LinalgFoldUnitExtentDimsPass> {
694+
using impl::LinalgFoldUnitExtentDimsPassBase<
695+
LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
693696
void runOnOperation() override {
694697
Operation *op = getOperation();
695698
MLIRContext *context = op->getContext();
@@ -705,7 +708,3 @@ struct LinalgFoldUnitExtentDimsPass
705708
}
706709
};
707710
} // namespace
708-
709-
std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
710-
return std::make_unique<LinalgFoldUnitExtentDimsPass>();
711-
}

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
#include <utility>
2828

2929
namespace mlir {
30-
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
31-
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSION
30+
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
3231
#include "mlir/Dialect/Linalg/Passes.h.inc"
3332
} // namespace mlir
3433

@@ -1927,8 +1926,10 @@ namespace {
19271926
// favor of test passes that check the functionality of each of the patterns
19281927
// added here individually.
19291928
struct LinalgElementwiseOpFusionPass
1930-
: public impl::LinalgElementwiseOpFusionBase<
1929+
: public impl::LinalgElementwiseOpFusionPassBase<
19311930
LinalgElementwiseOpFusionPass> {
1931+
using impl::LinalgElementwiseOpFusionPassBase<
1932+
LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
19321933
void runOnOperation() override {
19331934
Operation *op = getOperation();
19341935
MLIRContext *context = op->getContext();
@@ -1963,7 +1964,3 @@ struct LinalgElementwiseOpFusionPass
19631964
};
19641965

19651966
} // namespace
1966-
1967-
std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1968-
return std::make_unique<LinalgElementwiseOpFusionPass>();
1969-
}

mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "mlir/Transforms/DialectConversion.h"
1616

1717
namespace mlir {
18-
#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALG
18+
#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
1919
#include "mlir/Dialect/Linalg/Passes.h.inc"
2020
} // namespace mlir
2121

@@ -121,8 +121,10 @@ void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
121121

122122
namespace {
123123
class ConvertElementwiseToLinalgPass
124-
: public impl::ConvertElementwiseToLinalgBase<
124+
: public impl::ConvertElementwiseToLinalgPassBase<
125125
ConvertElementwiseToLinalgPass> {
126+
using impl::ConvertElementwiseToLinalgPassBase<
127+
ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
126128

127129
void runOnOperation() final {
128130
auto *func = getOperation();
@@ -140,7 +142,3 @@ class ConvertElementwiseToLinalgPass
140142
}
141143
};
142144
} // namespace
143-
144-
std::unique_ptr<Pass> mlir::createConvertElementwiseToLinalgPass() {
145-
return std::make_unique<ConvertElementwiseToLinalgPass>();
146-
}

mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#include "llvm/Support/Debug.h"
2727

2828
namespace mlir {
29-
#define GEN_PASS_DEF_LINALGGENERALIZATION
29+
#define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
3030
#include "mlir/Dialect/Linalg/Passes.h.inc"
3131
} // namespace mlir
3232

@@ -76,14 +76,17 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
7676

7777
namespace {
7878

79-
struct LinalgGeneralizationPass
80-
: public impl::LinalgGeneralizationBase<LinalgGeneralizationPass> {
79+
struct LinalgGeneralizeNamedOpsPass
80+
: public impl::LinalgGeneralizeNamedOpsPassBase<
81+
LinalgGeneralizeNamedOpsPass> {
82+
using impl::LinalgGeneralizeNamedOpsPassBase<
83+
LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
8184
void runOnOperation() override;
8285
};
8386

8487
} // namespace
8588

86-
void LinalgGeneralizationPass::runOnOperation() {
89+
void LinalgGeneralizeNamedOpsPass::runOnOperation() {
8790
RewritePatternSet patterns(&getContext());
8891
populateLinalgNamedOpsGeneralizationPatterns(patterns);
8992
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
@@ -93,7 +96,3 @@ void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
9396
RewritePatternSet &patterns) {
9497
patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
9598
}
96-
97-
std::unique_ptr<Pass> mlir::createLinalgGeneralizationPass() {
98-
return std::make_unique<LinalgGeneralizationPass>();
99-
}

0 commit comments

Comments
 (0)