Skip to content

Commit 2b04291

Browse files
authored
[mlir][Vector] Fix scalable InsertSlice/ExtractSlice lowering (#124861)
It looks like scalable `vector.insertslice/extractslice` ops made their way through lowering patterns that generate `vector.shuffle` ops. I'm not sure why this wasn't caught by the verifier, probably because the shuffle op was folded into something else as part of the same rewrite and the IR wasn't verified. This PR fixes the issue by preventing scalable vector.insertslice/extractslice ops to be lowered to vector shuffles. Instead, they are now lowered to a sequence of insertslice/extractelement ops using an existing patter.
1 parent 213a939 commit 2b04291

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
9696
PatternRewriter &rewriter) const override {
9797
auto srcType = op.getSourceVectorType();
9898
auto dstType = op.getDestVectorType();
99+
int64_t srcRank = srcType.getRank();
100+
101+
// Scalable vectors are not supported by vector shuffle.
102+
if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
103+
return failure();
99104

100105
if (op.getOffsets().getValue().empty())
101106
return failure();
102107

103-
int64_t srcRank = srcType.getRank();
104108
int64_t dstRank = dstType.getRank();
105109
assert(dstRank >= srcRank);
106110
if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
184188
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
185189
PatternRewriter &rewriter) const override {
186190
auto dstType = op.getType();
191+
auto srcType = op.getSourceVectorType();
192+
193+
// Scalable vectors are not supported by vector shuffle.
194+
if (dstType.isScalable() || srcType.isScalable())
195+
return failure();
187196

188197
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
189198

@@ -309,6 +318,8 @@ class DecomposeNDExtractStridedSlice
309318
}
310319
};
311320

321+
// TODO: Make sure these `populate*` patterns are tested in isolation.
322+
312323
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
313324
RewritePatternSet &patterns, PatternBenefit benefit) {
314325
patterns.add<DecomposeDifferentRankInsertStridedSlice,
@@ -331,4 +342,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
331342
patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
332343
Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
333344
benefit);
345+
// Generate chains of extract/insert ops for scalable vectors only as they
346+
// can't be lowered to vector shuffles.
347+
populateVectorExtractStridedSliceToExtractInsertChainPatterns(
348+
patterns,
349+
/*controlFn=*/
350+
[](ExtractStridedSliceOp op) {
351+
return op.getType().isScalable() ||
352+
op.getSourceVectorType().isScalable();
353+
},
354+
benefit);
334355
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
20262026
// CHECK-LABEL: func.func @extract_strided_slice_f32_1d_from_2d_scalable(
20272027
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
20282028
// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
2029-
// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
2030-
// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
2031-
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
2032-
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
2033-
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
2034-
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
2035-
// CHECK: return %[[T5]]
2029+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
2030+
// CHECK: %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
2031+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
2032+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
2033+
// CHECK: %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
2034+
// CHECK: %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
2035+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
2036+
// CHECK: return %[[RES]]
20362037

20372038
// -----
20382039

0 commit comments

Comments
 (0)