13
13
14
14
#include " mlir/Dialect/Vector/IR/VectorOps.h"
15
15
#include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16
+ #include " mlir/Dialect/Vector/Utils/VectorUtils.h"
16
17
#include " mlir/IR/BuiltinTypes.h"
17
18
#include " mlir/IR/PatternMatch.h"
18
19
@@ -23,66 +24,62 @@ using namespace mlir::vector;
23
24
24
25
namespace {
25
26
26
- // / Progressive lowering of InterleaveOp.
27
- // /
28
- // / Each leading dimension is unrolled until the result of the interleave is
29
- // / rank 1 (or the dimension is scalable, so can't be unrolled).
27
+ // / A one-shot unrolling of vector.interleave to the `targetRank`.
30
28
// /
31
29
// / Example:
32
30
// /
31
+ // / ```mlir
32
+ // / vector.interleave %a, %b : vector<1x2x3x4xi64>
33
33
// / ```
34
- // / %0 = vector.interleave %lhs, %rhs : vector<2x...8xty>
35
- // / ```
36
- // / Becomes:
37
- // / ```
38
- // / %lhs_0 = vector.extract %lhs[0]
39
- // / %rhs_0 = vector.extract %rhs[0]
40
- // / %lhs_1 = vector.extract %lhs[1]
41
- // / %rhs_1 = vector.extract %rhs[1]
42
- // / %zip_0 = vector.interleave %lhs_0, %rhs_0
43
- // / %zip_1 = vector.interleave %lhs_1, %rhs_1
44
- // / %res_0 = vector.insert %zip_0, %undef[0]
45
- // / %0 = vector.insert %zip_1, %res_0[1]
34
+ // / Would be unrolled to:
35
+ // / ```mlir
36
+ // / %result = arith.constant dense<0> : vector<1x2x3x8xi64>
37
+ // / %0 = vector.extract %a[0, 0, 0] ─┐
38
+ // / : vector<4xi64> from vector<1x2x3x4xi64> |
39
+ // / %1 = vector.extract %b[0, 0, 0] |
40
+ // / : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
41
+ // / %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions
42
+ // / %3 = vector.insert %2, %result [0, 0, 0] |
43
+ // / : vector<8xi64> into vector<1x2x3x8xi64> ┘
46
44
// / ```
47
45
// /
48
- // / If %zip_0 and %zip_1 still have a rank > 1 they will be unrolled again
49
- // / following the same pattern .
50
- class InterleaveOpLowering : public OpRewritePattern <vector::InterleaveOp> {
46
+ // / Note: If any leading dimension before the `targetRank` is scalable the
47
+ // / unrolling will stop before the scalable dimension .
48
+ class UnrollInterleaveOp : public OpRewritePattern <vector::InterleaveOp> {
51
49
public:
52
- using OpRewritePattern::OpRewritePattern;
50
+ UnrollInterleaveOp (int64_t targetRank, MLIRContext *context,
51
+ PatternBenefit benefit = 1 )
52
+ : OpRewritePattern(context, benefit), targetRank(targetRank){};
53
53
54
54
LogicalResult matchAndRewrite (vector::InterleaveOp op,
55
55
PatternRewriter &rewriter) const override {
56
56
VectorType resultType = op.getResultVectorType ();
57
- // 1-D vector.interleave ops can be directly lowered to LLVM (later).
58
- if (resultType. getRank () == 1 )
57
+ auto unrollIterator = vector::createUnrollIterator (resultType, targetRank);
58
+ if (!unrollIterator )
59
59
return failure ();
60
60
61
- // Below we unroll the leading (or front) dimension. If that dimension is
62
- // scalable we can't unroll it.
63
- if (resultType.getScalableDims ().front ())
64
- return failure ();
65
-
66
- // n-D case: Unroll the leading dimension.
67
61
auto loc = op.getLoc ();
68
62
Value result = rewriter.create <arith::ConstantOp>(
69
63
loc, resultType, rewriter.getZeroAttr (resultType));
70
- for (int idx = 0 , end = resultType. getDimSize ( 0 ); idx < end; ++idx ) {
71
- Value extractLhs = rewriter.create <ExtractOp>(loc, op.getLhs (), idx );
72
- Value extractRhs = rewriter.create <ExtractOp>(loc, op.getRhs (), idx );
64
+ for (auto position : *unrollIterator ) {
65
+ Value extractLhs = rewriter.create <ExtractOp>(loc, op.getLhs (), position );
66
+ Value extractRhs = rewriter.create <ExtractOp>(loc, op.getRhs (), position );
73
67
Value interleave =
74
68
rewriter.create <InterleaveOp>(loc, extractLhs, extractRhs);
75
- result = rewriter.create <InsertOp>(loc, interleave, result, idx );
69
+ result = rewriter.create <InsertOp>(loc, interleave, result, position );
76
70
}
77
71
78
72
rewriter.replaceOp (op, result);
79
73
return success ();
80
74
}
75
+
76
+ private:
77
+ int64_t targetRank = 1 ;
81
78
};
82
79
83
80
} // namespace
84
81
85
82
void mlir::vector::populateVectorInterleaveLoweringPatterns (
86
- RewritePatternSet &patterns, PatternBenefit benefit) {
87
- patterns.add <InterleaveOpLowering>( patterns.getContext (), benefit);
83
+ RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
84
+ patterns.add <UnrollInterleaveOp>(targetRank, patterns.getContext (), benefit);
88
85
}
0 commit comments