13
13
#include < utility>
14
14
15
15
#include " mlir/Dialect/Vector/IR/VectorOps.h"
16
- #include " mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
17
16
#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
18
- #include " mlir/IR/BuiltinOps.h"
19
17
#include " mlir/IR/PatternMatch.h"
20
18
#include " mlir/Support/LogicalResult.h"
21
19
20
+ #include " mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
21
+
22
22
namespace mlir {
23
23
class RewritePatternSet ;
24
24
@@ -57,7 +57,7 @@ struct UnrollVectorOptions {
57
57
}
58
58
59
59
// / Function that returns the traversal order (in terms of "for loop order",
60
- // / i.e. slowest varying dimension to fastest varying dimension) that shoudl
60
+ // / i.e. slowest varying dimension to fastest varying dimension) that should
61
61
// / be used when unrolling the given operation into units of the native vector
62
62
// / size.
63
63
using UnrollTraversalOrderFnType =
@@ -70,10 +70,6 @@ struct UnrollVectorOptions {
70
70
}
71
71
};
72
72
73
- // ===----------------------------------------------------------------------===//
74
- // Vector transformation exposed as populate functions over rewrite patterns.
75
- // ===----------------------------------------------------------------------===//
76
-
77
73
// / Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
78
74
// / semantics to a contraction with MMT semantics (matrix matrix multiplication
79
75
// / with the RHS transposed). This specific form is meant to have the vector
@@ -134,10 +130,6 @@ void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
134
130
void populateVectorTransferFullPartialPatterns (
135
131
RewritePatternSet &patterns, const VectorTransformsOptions &options);
136
132
137
- // ===----------------------------------------------------------------------===//
138
- // Vector.transfer patterns.
139
- // ===----------------------------------------------------------------------===//
140
-
141
133
// / Collect a set of patterns to reduce the rank of the operands of vector
142
134
// / transfer ops to operate on the largest contigious vector.
143
135
// / These patterns are useful when lowering to dialects with 1d vector type
@@ -263,6 +255,49 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
263
255
const UnrollVectorOptions &options,
264
256
PatternBenefit benefit = 1 );
265
257
258
+ // / Collect a set of vector.shape_cast folding patterns.
259
+ void populateShapeCastFoldingPatterns (RewritePatternSet &patterns,
260
+ PatternBenefit benefit = 1 );
261
+
262
+ // / Collect a set of leading one dimension removal patterns.
263
+ // /
264
+ // / These patterns insert vector.shape_cast to remove leading one dimensions
265
+ // / to expose more canonical forms of read/write/insert/extract operations.
266
+ // / With them, there are more chances that we can cancel out extract-insert
267
+ // / pairs or forward write-read pairs.
268
+ void populateCastAwayVectorLeadingOneDimPatterns (RewritePatternSet &patterns,
269
+ PatternBenefit benefit = 1 );
270
+
271
+ // / Collect a set of one dimension removal patterns.
272
+ // /
273
+ // / These patterns insert rank-reducing memref.subview ops to remove one
274
+ // / dimensions. With them, there are more chances that we can avoid
275
+ // / potentially expensive vector.shape_cast operations.
276
+ void populateVectorTransferDropUnitDimsPatterns (RewritePatternSet &patterns,
277
+ PatternBenefit benefit = 1 );
278
+
279
+ // / Collect a set of patterns to flatten n-D vector transfers on contiguous
280
+ // / memref.
281
+ // /
282
+ // / These patterns insert memref.collapse_shape + vector.shape_cast patterns
283
+ // / to transform multiple small n-D transfers into a larger 1-D transfer where
284
+ // / the memref contiguity properties allow it.
285
+ void populateFlattenVectorTransferPatterns (RewritePatternSet &patterns,
286
+ PatternBenefit benefit = 1 );
287
+
288
+ // / Collect a set of patterns that bubble up/down bitcast ops.
289
+ // /
290
+ // / These patterns move vector.bitcast ops to be before insert ops or after
291
+ // / extract ops where suitable. With them, bitcast will happen on smaller
292
+ // / vectors and there are more chances to share extract/insert ops.
293
+ void populateBubbleVectorBitCastOpPatterns (RewritePatternSet &patterns,
294
+ PatternBenefit benefit = 1 );
295
+
296
+ // / These patterns materialize masks for various vector ops such as transfers.
297
+ void populateVectorMaskMaterializationPatterns (RewritePatternSet &patterns,
298
+ bool force32BitVectorIndices,
299
+ PatternBenefit benefit = 1 );
300
+
266
301
} // namespace vector
267
302
} // namespace mlir
268
303
0 commit comments