@@ -27,14 +27,22 @@ struct TestVectorToVectorConversion
27
27
void runOnFunction () override {
28
28
OwningRewritePatternList patterns;
29
29
auto *ctx = &getContext ();
30
- patterns.insert <UnrollVectorPattern<AddFOp>>(
31
- ctx, UnrollVectorOptions ().setNativeShape (ArrayRef<int64_t >{2 , 2 }));
32
- patterns.insert <UnrollVectorPattern<vector::ContractionOp>>(
33
- ctx, UnrollVectorOptions ().setNativeShape (ArrayRef<int64_t >{2 , 2 , 2 }));
30
+ patterns.insert <UnrollVectorPattern>(
31
+ ctx, UnrollVectorOptions ().setNativeShapeFn (getShape));
34
32
populateVectorToVectorCanonicalizationPatterns (patterns, ctx);
35
33
populateVectorToVectorTransformationPatterns (patterns, ctx);
36
34
applyPatternsAndFoldGreedily (getFunction (), std::move (patterns));
37
35
}
36
+
37
+ private:
38
+ // Return the target shape based on op type.
39
+ static Optional<SmallVector<int64_t , 4 >> getShape (Operation *op) {
40
+ if (isa<AddFOp>(op))
41
+ return SmallVector<int64_t , 4 >(2 , 2 );
42
+ if (isa<vector::ContractionOp>(op))
43
+ return SmallVector<int64_t , 4 >(3 , 2 );
44
+ return llvm::None;
45
+ }
38
46
};
39
47
40
48
struct TestVectorSlicesConversion
@@ -120,8 +128,11 @@ struct TestVectorUnrollingPatterns
120
128
void runOnFunction () override {
121
129
MLIRContext *ctx = &getContext ();
122
130
OwningRewritePatternList patterns;
123
- patterns.insert <UnrollVectorPattern<AddFOp>>(
124
- ctx, UnrollVectorOptions ().setNativeShape (ArrayRef<int64_t >{2 , 2 }));
131
+ patterns.insert <UnrollVectorPattern>(
132
+ ctx, UnrollVectorOptions ()
133
+ .setNativeShape (ArrayRef<int64_t >{2 , 2 })
134
+ .setFilterConstraint (
135
+ [](Operation *op) { return success (isa<AddFOp>(op)); }));
125
136
126
137
if (unrollBasedOnType) {
127
138
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
@@ -137,12 +148,19 @@ struct TestVectorUnrollingPatterns
137
148
}
138
149
return nativeShape;
139
150
};
140
- patterns.insert <UnrollVectorPattern<vector::ContractionOp>>(
141
- ctx, UnrollVectorOptions ().setNativeShapeFn (nativeShapeFn));
151
+ patterns.insert <UnrollVectorPattern>(
152
+ ctx, UnrollVectorOptions ()
153
+ .setNativeShapeFn (nativeShapeFn)
154
+ .setFilterConstraint ([](Operation *op) {
155
+ return success (isa<ContractionOp>(op));
156
+ }));
142
157
} else {
143
- patterns.insert <UnrollVectorPattern<vector::ContractionOp>>(
144
- ctx,
145
- UnrollVectorOptions ().setNativeShape (ArrayRef<int64_t >{2 , 2 , 2 }));
158
+ patterns.insert <UnrollVectorPattern>(
159
+ ctx, UnrollVectorOptions ()
160
+ .setNativeShape (ArrayRef<int64_t >{2 , 2 , 2 })
161
+ .setFilterConstraint ([](Operation *op) {
162
+ return success (isa<ContractionOp>(op));
163
+ }));
146
164
}
147
165
populateVectorToVectorCanonicalizationPatterns (patterns, ctx);
148
166
populateVectorToVectorTransformationPatterns (patterns, ctx);
@@ -273,10 +291,14 @@ struct TestVectorTransferUnrollingPatterns
273
291
void runOnFunction () override {
274
292
MLIRContext *ctx = &getContext ();
275
293
OwningRewritePatternList patterns;
276
- patterns.insert <UnrollVectorPattern<vector::TransferReadOp>>(
277
- ctx, UnrollVectorOptions ().setNativeShape (ArrayRef<int64_t >{2 , 2 }));
278
- patterns.insert <UnrollVectorPattern<vector::TransferWriteOp>>(
279
- ctx, UnrollVectorOptions ().setNativeShape (ArrayRef<int64_t >{2 , 2 }));
294
+ patterns.insert <UnrollVectorPattern>(
295
+ ctx,
296
+ UnrollVectorOptions ()
297
+ .setNativeShape (ArrayRef<int64_t >{2 , 2 })
298
+ .setFilterConstraint ([](Operation *op) {
299
+ return success (
300
+ isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
301
+ }));
280
302
populateVectorToVectorCanonicalizationPatterns (patterns, ctx);
281
303
populateVectorToVectorTransformationPatterns (patterns, ctx);
282
304
applyPatternsAndFoldGreedily (getFunction (), std::move (patterns));
0 commit comments