Skip to content

Commit 6b7d618

Browse files
committed
Fixup: Create and use vector::createUnrollIterator()
Instead of progressively unrolling one leading dimension at a time, this now uses `vector::createUnrollIterator()`, which returns an iterator over all positions in the leading dimensions of a vector type (stopping at a target rank).
1 parent c3a5790 commit 6b7d618

File tree

5 files changed

+102
-37
lines changed

5 files changed

+102
-37
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,11 @@ void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
266266

267267
/// Populate the pattern set with the following patterns:
268268
///
269-
/// [InterleaveOpLowering]
270-
/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
271-
/// InterleaveOp until dim 1.
269+
/// [UnrollInterleaveOp]
270+
/// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
271+
/// InterleaveOp (of `targetRank`) + InsertOp.
272272
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
273+
int64_t targetRank = 1,
273274
PatternBenefit benefit = 1);
274275

275276
} // namespace vector

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1010
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1111

12+
#include "mlir/Dialect/Utils/IndexingUtils.h"
1213
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1314
#include "mlir/IR/BuiltinAttributes.h"
1415
#include "mlir/Support/LLVM.h"
@@ -75,6 +76,28 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
7576
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
7677
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
7778

79+
/// Returns an iterator for all positions in the leading dimensions of `vType`
80+
/// up to the `targetRank`. If any leading dimension before the `targetRank` is
81+
/// scalable (so cannot be unrolled), it will return an iterator for positions
82+
/// up to the first scalable dimension.
83+
///
84+
/// If no leading dimensions can be unrolled, an empty optional is returned.
85+
///
86+
/// Examples:
87+
///
88+
/// For vType = vector<2x3x4> and targetRank = 1
89+
///
90+
/// The resulting iterator will yield:
91+
/// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
92+
///
93+
/// For vType = vector<3x[4]x5> and targetRank = 0
94+
///
95+
/// The scalable dimension blocks unrolling so the iterator yields only:
96+
/// [0], [1], [2]
97+
///
98+
std::optional<StaticTileOffsetRange>
99+
createUnrollIterator(VectorType vType, int64_t targetRank = 1);
100+
78101
} // namespace vector
79102

80103
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1515
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16+
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1617
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/PatternMatch.h"
1819

@@ -23,66 +24,62 @@ using namespace mlir::vector;
2324

2425
namespace {
2526

26-
/// Progressive lowering of InterleaveOp.
27-
///
28-
/// Each leading dimension is unrolled until the result of the interleave is
29-
/// rank 1 (or the dimension is scalable, so can't be unrolled).
27+
/// A one-shot unrolling of vector.interleave to the `targetRank`.
3028
///
3129
/// Example:
3230
///
31+
/// ```mlir
32+
/// vector.interleave %a, %b : vector<1x2x3x4xi64>
3333
/// ```
34-
/// %0 = vector.interleave %lhs, %rhs : vector<2x...8xty>
35-
/// ```
36-
/// Becomes:
37-
/// ```
38-
/// %lhs_0 = vector.extract %lhs[0]
39-
/// %rhs_0 = vector.extract %rhs[0]
40-
/// %lhs_1 = vector.extract %lhs[1]
41-
/// %rhs_1 = vector.extract %rhs[1]
42-
/// %zip_0 = vector.interleave %lhs_0, %rhs_0
43-
/// %zip_1 = vector.interleave %lhs_1, %rhs_1
44-
/// %res_0 = vector.insert %zip_0, %undef[0]
45-
/// %0 = vector.insert %zip_1, %res_0[1]
34+
/// Would be unrolled to:
35+
/// ```mlir
36+
/// %result = arith.constant dense<0> : vector<1x2x3x8xi64>
37+
/// %0 = vector.extract %a[0, 0, 0] ─┐
38+
/// : vector<4xi64> from vector<1x2x3x4xi64> |
39+
/// %1 = vector.extract %b[0, 0, 0] |
40+
/// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
41+
/// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions
42+
/// %3 = vector.insert %2, %result [0, 0, 0] |
43+
/// : vector<8xi64> into vector<1x2x3x8xi64> ┘
4644
/// ```
4745
///
48-
/// If %zip_0 and %zip_1 still have a rank > 1 they will be unrolled again
49-
/// following the same pattern.
50-
class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
46+
/// Note: If any leading dimension before the `targetRank` is scalable, the
47+
/// unrolling will stop before the scalable dimension.
48+
class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
5149
public:
52-
using OpRewritePattern::OpRewritePattern;
50+
UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
51+
PatternBenefit benefit = 1)
52+
: OpRewritePattern(context, benefit), targetRank(targetRank){};
5353

5454
LogicalResult matchAndRewrite(vector::InterleaveOp op,
5555
PatternRewriter &rewriter) const override {
5656
VectorType resultType = op.getResultVectorType();
57-
// 1-D vector.interleave ops can be directly lowered to LLVM (later).
58-
if (resultType.getRank() == 1)
57+
auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
58+
if (!unrollIterator)
5959
return failure();
6060

61-
// Below we unroll the leading (or front) dimension. If that dimension is
62-
// scalable we can't unroll it.
63-
if (resultType.getScalableDims().front())
64-
return failure();
65-
66-
// n-D case: Unroll the leading dimension.
6761
auto loc = op.getLoc();
6862
Value result = rewriter.create<arith::ConstantOp>(
6963
loc, resultType, rewriter.getZeroAttr(resultType));
70-
for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
71-
Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
72-
Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
64+
for (auto position : *unrollIterator) {
65+
Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), position);
66+
Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), position);
7367
Value interleave =
7468
rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
75-
result = rewriter.create<InsertOp>(loc, interleave, result, idx);
69+
result = rewriter.create<InsertOp>(loc, interleave, result, position);
7670
}
7771

7872
rewriter.replaceOp(op, result);
7973
return success();
8074
}
75+
76+
private:
77+
int64_t targetRank = 1;
8178
};
8279

8380
} // namespace
8481

8582
void mlir::vector::populateVectorInterleaveLoweringPatterns(
86-
RewritePatternSet &patterns, PatternBenefit benefit) {
87-
patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
83+
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
84+
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
8885
}

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,25 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
303303

304304
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
305305
}
306+
307+
std::optional<StaticTileOffsetRange>
308+
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
309+
if (vType.getRank() <= targetRank)
310+
return {};
311+
// Attempt to unroll until targetRank or the first scalable dimension (which
312+
// cannot be unrolled).
313+
auto shapeToUnroll = vType.getShape().drop_back(targetRank);
314+
auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
315+
auto it =
316+
std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
317+
auto firstScalableDim = it - scalableDimsToUnroll.begin();
318+
if (firstScalableDim == 0)
319+
return {};
320+
// All scalable dimensions should be removed now.
321+
scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
322+
assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
323+
"unexpected leading scalable dimension");
324+
// Create an unroll iterator for leading dimensions.
325+
shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
326+
return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
327+
}

mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,28 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
3636
return %0 : vector<2x[16]xi16>
3737
}
3838

39+
// CHECK-LABEL: @vector_interleave_4d
40+
// CHECK-SAME: %[[LHS:.*]]: vector<1x2x3x4xi64>, %[[RHS:.*]]: vector<1x2x3x4xi64>)
41+
func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64>) -> vector<1x2x3x8xi64>
42+
{
43+
// CHECK: %[[LHS_0:.*]] = vector.extract %[[LHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
44+
// CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
45+
// CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64>
46+
// CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64>
47+
// CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64>
48+
%0 = vector.interleave %a, %b : vector<1x2x3x4xi64>
49+
return %0 : vector<1x2x3x8xi64>
50+
}
51+
52+
// CHECK-LABEL: @vector_interleave_nd_with_scalable_dim
53+
func.func @vector_interleave_nd_with_scalable_dim(%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16>
54+
{
55+
// The scalable dim blocks unrolling so only the first two dims are unrolled.
56+
// CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16>
57+
%0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16>
58+
return %0 : vector<1x3x[2]x2x3x8xf16>
59+
}
60+
3961
module attributes {transform.with_named_sequence} {
4062
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
4163
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)