Skip to content

Commit e349fb7

Browse files
author
Nicolas Vasilache
committed
[mlir][Linalg] NFC - Make markers use Identifier instead of StringRef
Summary: This removes string ownership worries by putting everything into the context and allows more constructing identifiers programmatically. Reviewers: ftynse Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul Tags: #mlir Differential Revision: https://reviews.llvm.org/D81027
1 parent c546825 commit e349fb7

File tree

5 files changed

+80
-67
lines changed

5 files changed

+80
-67
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1313
#include "mlir/Dialect/Vector/VectorOps.h"
14+
#include "mlir/IR/Identifier.h"
1415
#include "mlir/IR/PatternMatch.h"
1516
#include "llvm/ADT/SmallBitVector.h"
1617

@@ -206,15 +207,16 @@ struct LinalgTransforms {
206207

207208
/// Helper class to control common attribute matching and setting behavior.
208209
struct LinalgMarker {
209-
LinalgMarker(ArrayRef<StringRef> matchDisjunction = {},
210-
Optional<StringRef> replacement = None);
211-
LinalgMarker(ArrayRef<StringRef> matchDisjunction, StringRef replacement);
210+
explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
211+
Optional<Identifier> replacement = None);
212+
LinalgMarker(LinalgMarker &&) = default;
213+
LinalgMarker(const LinalgMarker &) = default;
212214
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
213215
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
214216

215217
private:
216-
SmallVector<StringRef, 4> matchDisjunction;
217-
Optional<StringRef> replacement;
218+
SmallVector<Identifier, 4> matchDisjunction;
219+
Optional<Identifier> replacement;
218220
};
219221

220222
///

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,8 @@ class RewritePatternList<OpTy, OpTypes...> {
459459
public:
460460
static void insert(OwningRewritePatternList &patterns,
461461
const LinalgTilingOptions &options, MLIRContext *ctx) {
462-
patterns.insert<LinalgTilingPattern<OpTy>>(ctx, options,
463-
LinalgMarker({}, "tiled"));
462+
patterns.insert<LinalgTilingPattern<OpTy>>(
463+
ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
464464
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
465465
}
466466
};

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,11 @@ using llvm::dbgs;
4646
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
4747
"__internal_linalg_transform__";
4848

49-
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
50-
Optional<StringRef> replacement)
49+
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
50+
Optional<Identifier> replacement)
5151
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
5252
replacement(replacement) {}
5353

54-
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
55-
StringRef replacement)
56-
: LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
57-
5854
LogicalResult
5955
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
6056
Operation *op) const {
@@ -66,12 +62,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
6662
if (matchDisjunction.empty())
6763
return success();
6864

69-
// 2. Has no marker and matchDisjuntion matches the no-moarker case.
70-
for (auto marker : matchDisjunction)
71-
if (marker.empty())
72-
return success();
73-
74-
// 3. Has no marker but was expecting a marker.
65+
// 2. Has no marker but was expecting a marker.
7566
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
7667
diag << " does not have any marker from list: ";
7768
interleaveComma(matchDisjunction, diag);

mlir/test/Dialect/Linalg/transform-patterns.mlir

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
1515
%y: memref<?xf32, offset: ?, strides: [1]>,
1616
%v: memref<f32>) {
17-
linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
18-
memref<?xf32, offset: ?, strides: [1]>,
19-
memref<f32>
17+
linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } :
18+
memref<?xf32, offset: ?, strides: [1]>,
19+
memref<?xf32, offset: ?, strides: [1]>,
20+
memref<f32>
2021
return
2122
}
2223
// CHECK-LABEL: func @dot
@@ -35,9 +36,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
3536
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
3637
%x: memref<?xf32, offset: ?, strides: [1]>,
3738
%y: memref<?xf32, offset: ?, strides: [1]>) {
38-
linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
39-
memref<?xf32, offset: ?, strides: [1]>,
40-
memref<?xf32, offset: ?, strides: [1]>
39+
linalg.matvec(%A, %x, %y) :
40+
memref<?x?xf32, offset: ?, strides: [?, 1]>,
41+
memref<?xf32, offset: ?, strides: [1]>,
42+
memref<?xf32, offset: ?, strides: [1]>
4143
return
4244
}
4345
// CHECK-LABEL: func @matvec
@@ -51,9 +53,10 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
5153
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
5254
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
5355
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
54-
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
55-
memref<?x?xf32, offset: ?, strides: [?, 1]>,
56-
memref<?x?xf32, offset: ?, strides: [?, 1]>
56+
linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "MEM" } :
57+
memref<?x?xf32, offset: ?, strides: [?, 1]>,
58+
memref<?x?xf32, offset: ?, strides: [?, 1]>,
59+
memref<?x?xf32, offset: ?, strides: [?, 1]>
5760
return
5861
}
5962
// CHECK-LABEL: func @matmul

mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -66,26 +66,29 @@ static void applyPatterns(FuncOp funcOp) {
6666
//===--------------------------------------------------------------------===//
6767
patterns.insert<LinalgTilingPattern<MatmulOp>>(
6868
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
69-
LinalgMarker({"MEM", {}}, "L3"));
69+
LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
7070
patterns.insert<LinalgTilingPattern<MatmulOp>>(
7171
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
72-
LinalgMarker({"L3"}, "L2"));
72+
LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
7373
patterns.insert<LinalgTilingPattern<MatmulOp>>(
7474
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
75-
LinalgMarker({"L2"}, "L1"));
75+
LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
7676
patterns.insert<LinalgTilingPattern<MatmulOp>>(
7777
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
78-
LinalgMarker({"L1"}, "REG"));
78+
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
7979

8080
patterns.insert<LinalgTilingPattern<MatvecOp>>(
8181
ctx,
8282
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
8383
LinalgTilingLoopType::ParallelLoops),
84-
LinalgMarker({}, "L1"));
84+
LinalgMarker({}, Identifier::get("L1", ctx)));
8585

8686
patterns.insert<LinalgTilingPattern<DotOp>>(
8787
ctx, LinalgTilingOptions().setTileSizes(8000),
88-
LinalgMarker({"MEM", "L3", "L2", {}}, "REG"));
88+
LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
89+
Identifier::get("L3", ctx),
90+
Identifier::get("L2", ctx)},
91+
Identifier::get("REG", ctx)));
8992

9093
//===--------------------------------------------------------------------===//
9194
// Linalg tiling and permutation patterns.
@@ -95,75 +98,84 @@ static void applyPatterns(FuncOp funcOp) {
9598
LinalgTilingOptions()
9699
.setTileSizes({2000, 3000, 4000})
97100
.setInterchange({1, 2, 0}),
98-
LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
101+
LinalgMarker(Identifier::get("__with_perm__", ctx),
102+
Identifier::get("L2__with_perm__", ctx)));
99103
patterns.insert<LinalgTilingPattern<MatmulOp>>(
100104
ctx,
101105
LinalgTilingOptions()
102106
.setTileSizes({200, 300, 400})
103107
.setInterchange({1, 0, 2}),
104-
LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
108+
LinalgMarker(Identifier::get("L2__with_perm__", ctx),
109+
Identifier::get("L1__with_perm__", ctx)));
105110
patterns.insert<LinalgTilingPattern<MatmulOp>>(
106111
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
107-
LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
112+
LinalgMarker(Identifier::get("L1__with_perm__", ctx),
113+
Identifier::get("REG__with_perm__", ctx)));
108114

109115
patterns.insert<LinalgTilingPattern<MatvecOp>>(
110116
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
111-
LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
117+
LinalgMarker(Identifier::get("__with_perm__", ctx),
118+
Identifier::get("L1__with_perm__", ctx)));
112119

113120
patterns.insert<LinalgTilingPattern<MatmulOp>>(
114121
ctx,
115122
LinalgTilingOptions()
116123
.setTileSizes({16, 8, 4})
117124
.setInterchange({1, 2, 0})
118125
.setLoopType(LinalgTilingLoopType::ParallelLoops),
119-
LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));
126+
LinalgMarker(Identifier::get("par__with_perm__", ctx),
127+
Identifier::get("after_par__with_perm__", ctx)));
120128

121129
//===--------------------------------------------------------------------===//
122130
// Linalg to loops patterns.
123131
//===--------------------------------------------------------------------===//
124132
patterns.insert<LinalgLoweringPattern<DotOp>>(
125133
ctx,
126-
/*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
134+
/*loweringType=*/LinalgLoweringType::Loops,
135+
LinalgMarker(Identifier::get("REG", ctx)));
127136

128137
//===--------------------------------------------------------------------===//
129138
// Linalg to vector contraction patterns.
130139
//===--------------------------------------------------------------------===//
131140
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
132141
LinalgVectorizationPattern<FillOp>,
133142
LinalgVectorizationPattern<GenericOp>>(
134-
ctx, LinalgMarker({"VECTORIZE"}));
143+
ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
135144

136145
//===--------------------------------------------------------------------===//
137146
// Linalg generic permutation patterns.
138147
//===--------------------------------------------------------------------===//
139148
patterns.insert<LinalgInterchangePattern<GenericOp>>(
140149
ctx,
141150
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
142-
LinalgMarker({}, "PERMUTED"));
151+
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
143152
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
144153
ctx,
145154
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
146-
LinalgMarker({}, "PERMUTED"));
155+
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
147156

148157
//===--------------------------------------------------------------------===//
149158
// Linalg subview operands promotion.
150159
//===--------------------------------------------------------------------===//
151160
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
152161
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
153-
LinalgMarker({"_promote_views_"}, "_views_promoted_"));
162+
LinalgMarker(Identifier::get("_promote_views_", ctx),
163+
Identifier::get("_views_promoted_", ctx)));
154164
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
155165
ctx,
156166
LinalgPromotionOptions()
157167
.setOperandsToPromote({0})
158168
.useFullTileBuffersByDefault(),
159-
LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
169+
LinalgMarker(Identifier::get("_promote_first_view_", ctx),
170+
Identifier::get("_first_view_promoted_", ctx)));
160171
patterns.insert<LinalgPromotionPattern<FillOp>>(
161172
ctx,
162173
LinalgPromotionOptions()
163174
.setOperandsToPromote({0})
164175
.setUseFullTileBuffers({true})
165176
.setAlignment(32),
166-
LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
177+
LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
178+
Identifier::get("_views_aligned_promoted_", ctx)));
167179

168180
applyPatternsAndFoldGreedily(funcOp, patterns);
169181

@@ -176,21 +188,22 @@ static void applyPatterns(FuncOp funcOp) {
176188
static void fillL1TilingAndMatmulToVectorPatterns(
177189
FuncOp funcOp, StringRef startMarker,
178190
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
179-
MLIRContext *context = funcOp.getContext();
191+
MLIRContext *ctx = funcOp.getContext();
180192
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
181-
context,
193+
ctx,
182194
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
183-
LinalgMarker({startMarker}, "L1")));
195+
LinalgMarker(Identifier::get(startMarker, ctx),
196+
Identifier::get("L1", ctx))));
184197

185198
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
186-
context, LinalgPromotionOptions().useFullTileBuffersByDefault(),
187-
LinalgMarker({"L1"}, "VEC")));
199+
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
200+
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx))));
188201

189-
patternsVector.emplace_back(
190-
LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
202+
patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>(
203+
ctx, LinalgMarker(Identifier::get("VEC", ctx))));
191204
patternsVector.back()
192205
.insert<LinalgVectorizationPattern<FillOp>,
193-
LinalgVectorizationPattern<CopyOp>>(context);
206+
LinalgVectorizationPattern<CopyOp>>(ctx);
194207
}
195208

196209
//===----------------------------------------------------------------------===//
@@ -231,13 +244,14 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
231244
return success();
232245
}
233246

234-
void fillPromotionCallBackPatterns(MLIRContext *context,
247+
void fillPromotionCallBackPatterns(MLIRContext *ctx,
235248
OwningRewritePatternList &patterns) {
236249
patterns.insert<LinalgTilingPattern<MatmulOp>>(
237-
context, LinalgTilingOptions().setTileSizes({16, 16, 16}),
238-
LinalgMarker({"START"}, "PROMOTE"));
250+
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
251+
LinalgMarker(Identifier::get("START", ctx),
252+
Identifier::get("PROMOTE", ctx)));
239253
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
240-
context,
254+
ctx,
241255
LinalgPromotionOptions()
242256
.setOperandsToPromote({0, 2})
243257
.setUseFullTileBuffers({false, false})
@@ -251,7 +265,7 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
251265
copyCallBackFn(b, src, dst, true);
252266
return success();
253267
}),
254-
LinalgMarker({"PROMOTE"}));
268+
LinalgMarker(Identifier::get("PROMOTE", ctx)));
255269
}
256270

257271
static void
@@ -261,15 +275,18 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
261275
MLIRContext *ctx = funcOp.getContext();
262276
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
263277
if (testMatmulToVectorPatterns1dTiling) {
264-
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
278+
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
279+
stage1Patterns);
265280
} else if (testMatmulToVectorPatterns2dTiling) {
266-
stage1Patterns.emplace_back(
267-
LinalgTilingPattern<MatmulOp>(ctx,
268-
LinalgTilingOptions()
269-
.setTileSizes({768, 264, 768})
270-
.setInterchange({1, 2, 0}),
271-
LinalgMarker({"START"}, "L2")));
272-
fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
281+
stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
282+
ctx,
283+
LinalgTilingOptions()
284+
.setTileSizes({768, 264, 768})
285+
.setInterchange({1, 2, 0}),
286+
LinalgMarker(Identifier::get("START", ctx),
287+
Identifier::get("L2", ctx))));
288+
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
289+
stage1Patterns);
273290
}
274291
OwningRewritePatternList stage2Patterns =
275292
getLinalgTilingCanonicalizationPatterns(ctx);

0 commit comments

Comments
 (0)