Skip to content

Commit dfb8537

Browse files
committed
fixup! [mlir][vector] Restrict narrow-type-emulation patterns
Update tests that were still loading/storing 1-D vectors
1 parent 711a548 commit dfb8537

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
217217
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
218218
ConversionPatternRewriter &rewriter) const override {
219219

220+
// See #115653
220221
if (op.getValueToStore().getType().getRank() != 1)
221222
return rewriter.notifyMatchFailure(op,
222223
"only 1-D vectors are supported ATM");
@@ -287,6 +288,7 @@ struct ConvertVectorMaskedStore final
287288
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
288289
ConversionPatternRewriter &rewriter) const override {
289290

291+
// See #115653
290292
if (op.getValueToStore().getType().getRank() != 1)
291293
return rewriter.notifyMatchFailure(op,
292294
"only 1-D vectors are supported ATM");
@@ -380,6 +382,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
380382
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
381383
ConversionPatternRewriter &rewriter) const override {
382384

385+
// See #115653
383386
if (op.getVectorType().getRank() != 1)
384387
return rewriter.notifyMatchFailure(op,
385388
"only 1-D vectors are supported ATM");
@@ -485,6 +488,7 @@ struct ConvertVectorMaskedLoad final
485488
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
486489
ConversionPatternRewriter &rewriter) const override {
487490

491+
// See #115653
488492
if (op.getVectorType().getRank() != 1)
489493
return rewriter.notifyMatchFailure(op,
490494
"only 1-D vectors are supported ATM");
@@ -640,6 +644,7 @@ struct ConvertVectorTransferRead final
640644
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
641645
ConversionPatternRewriter &rewriter) const override {
642646

647+
// See #115653
643648
if (op.getVectorType().getRank() != 1)
644649
return rewriter.notifyMatchFailure(op,
645650
"only 1-D vectors are supported ATM");

mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,11 @@ func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x
6565
%c16 = arith.constant 16 : index
6666
%c8 = arith.constant 8 : index
6767
%cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
68-
%cst_2 = arith.constant dense<0> : vector<16xi4>
68+
%cst_2 = arith.constant dense<0> : vector<8x16xi4>
6969
%27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
7070
%48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
71-
%49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
72-
%50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
73-
%63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
71+
%50 = vector.maskedload %0[%c0, %c0, %c0], %48, %cst_2 : memref<8x8x16xi4>, vector<8x16xi1>, vector<8x16xi4> into vector<8x16xi4>
72+
%63 = vector.insert %50, %cst_1 [0] : vector<8x16xi4> into vector<8x8x16xi4>
7473
return %63 : vector<8x8x16xi4>
7574
}
7675

@@ -84,9 +83,9 @@ func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x
8483
/// vector.store
8584
///----------------------------------------------------------------------------------------
8685

87-
func.func @vector_store_2d_i8_negative(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
86+
func.func @vector_store_2d_i8_negative(%arg0: vector<2x8xi8>, %arg1: index, %arg2: index) {
8887
%0 = memref.alloc() : memref<4x8xi8>
89-
vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
88+
vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<2x8xi8>
9089
return
9190
}
9291

@@ -100,10 +99,10 @@ func.func @vector_store_2d_i8_negative(%arg0: vector<8xi8>, %arg1: index, %arg2:
10099
/// vector.maskedstore
101100
///----------------------------------------------------------------------------------------
102101

103-
func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
102+
func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<2x8xi8>) {
104103
%0 = memref.alloc() : memref<3x8xi8>
105-
%mask = vector.create_mask %arg2 : vector<8xi1>
106-
vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
104+
%mask = vector.create_mask %arg2, %arg2 : vector<2x8xi1>
105+
vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<2x8xi1>, vector<2x8xi8>
107106
return
108107
}
109108

0 commit comments

Comments
 (0)