Skip to content

Commit 50febde

Browse files
authored
[mlir][vector] Bugfix of linearize vector.extract (#106836)
This patch add check for `vector.extract` with scalar type, which is not allowed when linearize `vector.extract`. Fix #106162.
1 parent 030e4d0 commit 50febde

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,10 @@ struct LinearizeVectorExtract final
337337
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
338338
ConversionPatternRewriter &rewriter) const override {
339339
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
340+
if (!dstTy)
341+
return rewriter.notifyMatchFailure(extractOp,
342+
"expected n-D vector type.");
343+
340344
if (extractOp.getVector().getType().isScalable() ||
341345
cast<VectorType>(dstTy).isScalable())
342346
return rewriter.notifyMatchFailure(extractOp,

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,15 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
306306
// ALL: return %[[RES]] : vector<2x8x[4]xf32>
307307
return %0 : vector<2x8x[4]xf32>
308308
}
309+
310+
// -----
311+
312+
// ALL-LABEL: test_vector_extract_scalar
313+
func.func @test_vector_extract_scalar() {
314+
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
315+
// ALL-NOT: vector.shuffle
316+
// ALL: vector.extract
317+
// ALL-NOT: vector.shuffle
318+
%0 = vector.extract %cst[0] : i32 from vector<4xi32>
319+
return
320+
}

0 commit comments

Comments
 (0)