Skip to content

Commit b1b0384

Browse files
committed
[mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize
1 parent 7bea41e commit b1b0384

File tree

2 files changed

+56
-15
lines changed

2 files changed

+56
-15
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,12 @@ struct LinearizeVectorExtractStridedSlice final
151151
LogicalResult
152152
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
153153
ConversionPatternRewriter &rewriter) const override {
154-
Type dstType = getTypeConverter()->convertType(extractOp.getType());
155-
assert(!(extractOp.getVector().getType().isScalable() ||
156-
cast<VectorType>(dstType).isScalable()) &&
157-
"scalable vectors are not supported.");
154+
VectorType dstType =
155+
getTypeConverter()->convertType<VectorType>(extractOp.getType());
156+
assert(dstType && "vector type destination expected.");
157+
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
158+
return rewriter.notifyMatchFailure(extractOp,
159+
"scalable vectors are not supported.");
158160
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
159161
return rewriter.notifyMatchFailure(
160162
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -264,10 +266,14 @@ struct LinearizeVectorShuffle final
264266
LogicalResult
265267
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
266268
ConversionPatternRewriter &rewriter) const override {
267-
Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
269+
VectorType dstType =
270+
getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
271+
assert(dstType && "vector type destination expected.");
272+
// The assert is used because vector.shuffle does not support scalable
273+
// vectors.
268274
assert(!(shuffleOp.getV1VectorType().isScalable() ||
269275
shuffleOp.getV2VectorType().isScalable() ||
270-
cast<VectorType>(dstType).isScalable()) &&
276+
dstType.isScalable()) &&
271277
"scalable vectors are not supported.");
272278
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
273279
return rewriter.notifyMatchFailure(
@@ -336,9 +342,10 @@ struct LinearizeVectorExtract final
336342
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
337343
ConversionPatternRewriter &rewriter) const override {
338344
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
339-
assert(!(extractOp.getVector().getType().isScalable() ||
340-
cast<VectorType>(dstTy).isScalable()) &&
341-
"scalable vectors are not supported.");
345+
if (extractOp.getVector().getType().isScalable() ||
346+
cast<VectorType>(dstTy).isScalable())
347+
return rewriter.notifyMatchFailure(extractOp,
348+
"scalable vectors are not supported.");
342349
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
343350
return rewriter.notifyMatchFailure(
344351
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -394,10 +401,12 @@ struct LinearizeVectorInsert final
394401
LogicalResult
395402
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
396403
ConversionPatternRewriter &rewriter) const override {
397-
Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
398-
assert(!(insertOp.getDestVectorType().isScalable() ||
399-
cast<VectorType>(dstTy).isScalable()) &&
400-
"scalable vectors are not supported.");
404+
VectorType dstTy = getTypeConverter()->convertType<VectorType>(
405+
insertOp.getDestVectorType());
406+
assert(dstTy && "vector type destination expected.");
407+
if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
408+
return rewriter.notifyMatchFailure(insertOp,
409+
"scalable vectors are not supported.");
401410

402411
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
403412
targetVectorBitWidth))

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
129129
// -----
130130

131131
// ALL-LABEL: func.func @test_scalable_no_linearize(
132-
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
133-
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
132+
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>, %[[VAL_1:.*]]: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
133+
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>, %arg1: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
134134
// ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
135135
%0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
136136

@@ -177,6 +177,17 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
177177
return %0 : vector<2x2xf32>
178178
}
179179

180+
// ALL-LABEL: func.func @test_extract_strided_slice_scalable(
181+
// ALL-SAME: %[[VAL_0:.*]]: vector<2x[2]xf32>) -> vector<1x[2]xf32> {
182+
func.func @test_extract_strided_slice_scalable(%arg0: vector<2x[2]xf32>) -> vector<1x[2]xf32> {
183+
// CHECK-NOT: vector.shuffle
184+
// CHECK-NOT: vector.shape_cast
185+
// ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [1, 2], strides = [1, 1]} : vector<2x[2]xf32> to vector<1x[2]xf32>
186+
%0 = vector.extract_strided_slice %arg0 { sizes = [1, 2], strides = [1, 1], offsets = [1, 0] } : vector<2x[2]xf32> to vector<1x[2]xf32>
187+
// ALL: return %[[RES]] : vector<1x[2]xf32>
188+
return %0 : vector<1x[2]xf32>
189+
}
190+
180191
// -----
181192
// ALL-LABEL: test_extract_strided_slice_2
182193
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
@@ -246,6 +257,16 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
246257
return %0 : vector<8x2xf32>
247258
}
248259

260+
// ALL-LABEL: func.func @test_vector_extract_scalable(
261+
// ALL-SAME: %[[VAL_0:.*]]: vector<2x[2]xf32>) -> f32 {
262+
func.func @test_vector_extract_scalable(%arg1: vector<2x[2]xf32>) -> f32 {
263+
// CHECK-NOT: vector.shuffle
264+
// CHECK-NOT: vector.shape_cast
265+
// ALL: %[[RES:.*]] = vector.extract %[[VAL_0]][0, 0] : f32 from vector<2x[2]xf32>
266+
%0 = vector.extract %arg1[0, 0]: f32 from vector<2x[2]xf32>
267+
// ALL: return %[[RES]] : f32
268+
return %0 : f32
269+
}
249270
// -----
250271
// ALL-LABEL: test_vector_insert
251272
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -274,3 +295,14 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
274295
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
275296
return %0 : vector<2x8x4xf32>
276297
}
298+
299+
// ALL-LABEL: func.func @test_vector_insert_scalable(
300+
// ALL-SAME: %[[VAL_0:.*]]: vector<2x[2]xf32>, %[[VAL_1:.*]]: f32) -> vector<2x[2]xf32> {
301+
func.func @test_vector_insert_scalable(%arg0: vector<2x[2]xf32>, %arg1: f32) -> vector<2x[2]xf32> {
302+
// CHECK-NOT: vector.shuffle
303+
// CHECK-NOT: vector.shape_cast
304+
// ALL: %[[RES:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0, 0] : f32 into vector<2x[2]xf32>
305+
%0 = vector.insert %arg1, %arg0[0, 0]: f32 into vector<2x[2]xf32>
306+
// ALL: return %[[RES]] : vector<2x[2]xf32>
307+
return %0 : vector<2x[2]xf32>
308+
}

0 commit comments

Comments
 (0)