@@ -66,26 +66,29 @@ static void applyPatterns(FuncOp funcOp) {
66
66
// ===--------------------------------------------------------------------===//
67
67
patterns.insert <LinalgTilingPattern<MatmulOp>>(
68
68
ctx, LinalgTilingOptions ().setTileSizes ({2000 , 3000 , 4000 }),
69
- LinalgMarker ({ " MEM" , {}}, " L3" ));
69
+ LinalgMarker (Identifier::get ( " MEM" , ctx), Identifier::get ( " L3" , ctx) ));
70
70
patterns.insert <LinalgTilingPattern<MatmulOp>>(
71
71
ctx, LinalgTilingOptions ().setTileSizes ({200 , 300 , 400 }),
72
- LinalgMarker ({ " L3" }, " L2" ));
72
+ LinalgMarker (Identifier::get ( " L3" , ctx), Identifier::get ( " L2" , ctx) ));
73
73
patterns.insert <LinalgTilingPattern<MatmulOp>>(
74
74
ctx, LinalgTilingOptions ().setTileSizes ({20 , 30 , 40 }),
75
- LinalgMarker ({ " L2" }, " L1" ));
75
+ LinalgMarker (Identifier::get ( " L2" , ctx), Identifier::get ( " L1" , ctx) ));
76
76
patterns.insert <LinalgTilingPattern<MatmulOp>>(
77
77
ctx, LinalgTilingOptions ().setTileSizes ({2 , 3 , 4 }),
78
- LinalgMarker ({ " L1" }, " REG" ));
78
+ LinalgMarker (Identifier::get ( " L1" , ctx), Identifier::get ( " REG" , ctx) ));
79
79
80
80
patterns.insert <LinalgTilingPattern<MatvecOp>>(
81
81
ctx,
82
82
LinalgTilingOptions ().setTileSizes ({5 , 6 }).setLoopType (
83
83
LinalgTilingLoopType::ParallelLoops),
84
- LinalgMarker ({}, " L1" ));
84
+ LinalgMarker ({}, Identifier::get ( " L1" , ctx) ));
85
85
86
86
patterns.insert <LinalgTilingPattern<DotOp>>(
87
87
ctx, LinalgTilingOptions ().setTileSizes (8000 ),
88
- LinalgMarker ({" MEM" , " L3" , " L2" , {}}, " REG" ));
88
+ LinalgMarker (ArrayRef<Identifier>{Identifier::get (" MEM" , ctx),
89
+ Identifier::get (" L3" , ctx),
90
+ Identifier::get (" L2" , ctx)},
91
+ Identifier::get (" REG" , ctx)));
89
92
90
93
// ===--------------------------------------------------------------------===//
91
94
// Linalg tiling and permutation patterns.
@@ -95,75 +98,84 @@ static void applyPatterns(FuncOp funcOp) {
95
98
LinalgTilingOptions ()
96
99
.setTileSizes ({2000 , 3000 , 4000 })
97
100
.setInterchange ({1 , 2 , 0 }),
98
- LinalgMarker ({" __with_perm__" }, " L2__with_perm__" ));
101
+ LinalgMarker (Identifier::get (" __with_perm__" , ctx),
102
+ Identifier::get (" L2__with_perm__" , ctx)));
99
103
patterns.insert <LinalgTilingPattern<MatmulOp>>(
100
104
ctx,
101
105
LinalgTilingOptions ()
102
106
.setTileSizes ({200 , 300 , 400 })
103
107
.setInterchange ({1 , 0 , 2 }),
104
- LinalgMarker ({" L2__with_perm__" }, " L1__with_perm__" ));
108
+ LinalgMarker (Identifier::get (" L2__with_perm__" , ctx),
109
+ Identifier::get (" L1__with_perm__" , ctx)));
105
110
patterns.insert <LinalgTilingPattern<MatmulOp>>(
106
111
ctx, LinalgTilingOptions ().setTileSizes ({20 , 30 , 40 }),
107
- LinalgMarker ({" L1__with_perm__" }, " REG__with_perm__" ));
112
+ LinalgMarker (Identifier::get (" L1__with_perm__" , ctx),
113
+ Identifier::get (" REG__with_perm__" , ctx)));
108
114
109
115
patterns.insert <LinalgTilingPattern<MatvecOp>>(
110
116
ctx, LinalgTilingOptions ().setTileSizes ({5 , 6 }).setInterchange ({1 , 0 }),
111
- LinalgMarker ({" __with_perm__" }, " L1__with_perm__" ));
117
+ LinalgMarker (Identifier::get (" __with_perm__" , ctx),
118
+ Identifier::get (" L1__with_perm__" , ctx)));
112
119
113
120
patterns.insert <LinalgTilingPattern<MatmulOp>>(
114
121
ctx,
115
122
LinalgTilingOptions ()
116
123
.setTileSizes ({16 , 8 , 4 })
117
124
.setInterchange ({1 , 2 , 0 })
118
125
.setLoopType (LinalgTilingLoopType::ParallelLoops),
119
- LinalgMarker ({" par__with_perm__" }, " after_par__with_perm__" ));
126
+ LinalgMarker (Identifier::get (" par__with_perm__" , ctx),
127
+ Identifier::get (" after_par__with_perm__" , ctx)));
120
128
121
129
// ===--------------------------------------------------------------------===//
122
130
// Linalg to loops patterns.
123
131
// ===--------------------------------------------------------------------===//
124
132
patterns.insert <LinalgLoweringPattern<DotOp>>(
125
133
ctx,
126
- /* loweringType=*/ LinalgLoweringType::Loops, LinalgMarker ({" REG" }));
134
+ /* loweringType=*/ LinalgLoweringType::Loops,
135
+ LinalgMarker (Identifier::get (" REG" , ctx)));
127
136
128
137
// ===--------------------------------------------------------------------===//
129
138
// Linalg to vector contraction patterns.
130
139
// ===--------------------------------------------------------------------===//
131
140
patterns.insert <LinalgVectorizationPattern<MatmulOp>,
132
141
LinalgVectorizationPattern<FillOp>,
133
142
LinalgVectorizationPattern<GenericOp>>(
134
- ctx, LinalgMarker ({ " VECTORIZE" } ));
143
+ ctx, LinalgMarker (Identifier::get ( " VECTORIZE" , ctx) ));
135
144
136
145
// ===--------------------------------------------------------------------===//
137
146
// Linalg generic permutation patterns.
138
147
// ===--------------------------------------------------------------------===//
139
148
patterns.insert <LinalgInterchangePattern<GenericOp>>(
140
149
ctx,
141
150
/* interchangeVector=*/ ArrayRef<unsigned >{1 , 2 , 0 },
142
- LinalgMarker ({}, " PERMUTED" ));
151
+ LinalgMarker ({}, Identifier::get ( " PERMUTED" , ctx) ));
143
152
patterns.insert <LinalgInterchangePattern<IndexedGenericOp>>(
144
153
ctx,
145
154
/* interchangeVector=*/ ArrayRef<unsigned >{1 , 2 , 0 },
146
- LinalgMarker ({}, " PERMUTED" ));
155
+ LinalgMarker ({}, Identifier::get ( " PERMUTED" , ctx) ));
147
156
148
157
// ===--------------------------------------------------------------------===//
149
158
// Linalg subview operands promotion.
150
159
// ===--------------------------------------------------------------------===//
151
160
patterns.insert <LinalgPromotionPattern<MatmulOp>>(
152
161
ctx, LinalgPromotionOptions ().useFullTileBuffersByDefault (),
153
- LinalgMarker ({" _promote_views_" }, " _views_promoted_" ));
162
+ LinalgMarker (Identifier::get (" _promote_views_" , ctx),
163
+ Identifier::get (" _views_promoted_" , ctx)));
154
164
patterns.insert <LinalgPromotionPattern<MatmulOp>>(
155
165
ctx,
156
166
LinalgPromotionOptions ()
157
167
.setOperandsToPromote ({0 })
158
168
.useFullTileBuffersByDefault (),
159
- LinalgMarker ({" _promote_first_view_" }, " _first_view_promoted_" ));
169
+ LinalgMarker (Identifier::get (" _promote_first_view_" , ctx),
170
+ Identifier::get (" _first_view_promoted_" , ctx)));
160
171
patterns.insert <LinalgPromotionPattern<FillOp>>(
161
172
ctx,
162
173
LinalgPromotionOptions ()
163
174
.setOperandsToPromote ({0 })
164
175
.setUseFullTileBuffers ({true })
165
176
.setAlignment (32 ),
166
- LinalgMarker ({" _promote_views_aligned_" }, " _views_aligned_promoted_" ));
177
+ LinalgMarker (Identifier::get (" _promote_views_aligned_" , ctx),
178
+ Identifier::get (" _views_aligned_promoted_" , ctx)));
167
179
168
180
applyPatternsAndFoldGreedily (funcOp, patterns);
169
181
@@ -176,21 +188,22 @@ static void applyPatterns(FuncOp funcOp) {
176
188
static void fillL1TilingAndMatmulToVectorPatterns (
177
189
FuncOp funcOp, StringRef startMarker,
178
190
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
179
- MLIRContext *context = funcOp.getContext ();
191
+ MLIRContext *ctx = funcOp.getContext ();
180
192
patternsVector.emplace_back (LinalgTilingPattern<MatmulOp>(
181
- context ,
193
+ ctx ,
182
194
LinalgTilingOptions ().setTileSizes ({8 , 12 , 16 }).setInterchange ({1 , 0 , 2 }),
183
- LinalgMarker ({startMarker}, " L1" )));
195
+ LinalgMarker (Identifier::get (startMarker, ctx),
196
+ Identifier::get (" L1" , ctx))));
184
197
185
198
patternsVector.emplace_back (LinalgPromotionPattern<MatmulOp>(
186
- context , LinalgPromotionOptions ().useFullTileBuffersByDefault (),
187
- LinalgMarker ({ " L1" }, " VEC" )));
199
+ ctx , LinalgPromotionOptions ().useFullTileBuffersByDefault (),
200
+ LinalgMarker (Identifier::get ( " L1" , ctx), Identifier::get ( " VEC" , ctx) )));
188
201
189
- patternsVector.emplace_back (
190
- LinalgVectorizationPattern<MatmulOp>(context , LinalgMarker ({ " VEC" } )));
202
+ patternsVector.emplace_back (LinalgVectorizationPattern<MatmulOp>(
203
+ ctx , LinalgMarker (Identifier::get ( " VEC" , ctx) )));
191
204
patternsVector.back ()
192
205
.insert <LinalgVectorizationPattern<FillOp>,
193
- LinalgVectorizationPattern<CopyOp>>(context );
206
+ LinalgVectorizationPattern<CopyOp>>(ctx );
194
207
}
195
208
196
209
// ===----------------------------------------------------------------------===//
@@ -231,13 +244,14 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
231
244
return success ();
232
245
}
233
246
234
- void fillPromotionCallBackPatterns (MLIRContext *context ,
247
+ void fillPromotionCallBackPatterns (MLIRContext *ctx ,
235
248
OwningRewritePatternList &patterns) {
236
249
patterns.insert <LinalgTilingPattern<MatmulOp>>(
237
- context, LinalgTilingOptions ().setTileSizes ({16 , 16 , 16 }),
238
- LinalgMarker ({" START" }, " PROMOTE" ));
250
+ ctx, LinalgTilingOptions ().setTileSizes ({16 , 16 , 16 }),
251
+ LinalgMarker (Identifier::get (" START" , ctx),
252
+ Identifier::get (" PROMOTE" , ctx)));
239
253
patterns.insert <LinalgPromotionPattern<MatmulOp>>(
240
- context ,
254
+ ctx ,
241
255
LinalgPromotionOptions ()
242
256
.setOperandsToPromote ({0 , 2 })
243
257
.setUseFullTileBuffers ({false , false })
@@ -251,7 +265,7 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
251
265
copyCallBackFn (b, src, dst, true );
252
266
return success ();
253
267
}),
254
- LinalgMarker ({ " PROMOTE" } ));
268
+ LinalgMarker (Identifier::get ( " PROMOTE" , ctx) ));
255
269
}
256
270
257
271
static void
@@ -261,15 +275,18 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
261
275
MLIRContext *ctx = funcOp.getContext ();
262
276
SmallVector<OwningRewritePatternList, 4 > stage1Patterns;
263
277
if (testMatmulToVectorPatterns1dTiling) {
264
- fillL1TilingAndMatmulToVectorPatterns (funcOp, " START" , stage1Patterns);
278
+ fillL1TilingAndMatmulToVectorPatterns (funcOp, Identifier::get (" START" , ctx),
279
+ stage1Patterns);
265
280
} else if (testMatmulToVectorPatterns2dTiling) {
266
- stage1Patterns.emplace_back (
267
- LinalgTilingPattern<MatmulOp>(ctx,
268
- LinalgTilingOptions ()
269
- .setTileSizes ({768 , 264 , 768 })
270
- .setInterchange ({1 , 2 , 0 }),
271
- LinalgMarker ({" START" }, " L2" )));
272
- fillL1TilingAndMatmulToVectorPatterns (funcOp, " L2" , stage1Patterns);
281
+ stage1Patterns.emplace_back (LinalgTilingPattern<MatmulOp>(
282
+ ctx,
283
+ LinalgTilingOptions ()
284
+ .setTileSizes ({768 , 264 , 768 })
285
+ .setInterchange ({1 , 2 , 0 }),
286
+ LinalgMarker (Identifier::get (" START" , ctx),
287
+ Identifier::get (" L2" , ctx))));
288
+ fillL1TilingAndMatmulToVectorPatterns (funcOp, Identifier::get (" L2" , ctx),
289
+ stage1Patterns);
273
290
}
274
291
OwningRewritePatternList stage2Patterns =
275
292
getLinalgTilingCanonicalizationPatterns (ctx);
0 commit comments