Skip to content

Commit e458434

Browse files
authored
[mlir][vector] Restrict narrow-type-emulation patterns (#115612)
All patterns in populateVectorNarrowTypeEmulationPatterns currently assume a 1-D vector load/store rather than an n-D vector load/store. This assumption is evident in ConvertVectorTransferRead, for example, here (extracted from `ConvertVectorTransferRead`): ```cpp auto newRead = rewriter.create<vector::TransferReadOp>( loc, VectorType::get(numElements, newElementType), adaptor.getSource(), getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices), newPadding); auto bitCast = rewriter.create<vector::BitCastOp>( loc, VectorType::get(numElements * scale, oldElementType), newRead); ``` Both invocations of `VectorType::get()` here generate a 1-D vector. Attempts to use these patterns with more generic cases, such as 2-D vectors, fail. For example, trying to cast the following 2-D case to `i32`: ```mlir func.func @vector_maskedload_2d_i8_negative( %idx1: index, %idx2: index, %num_elems: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> { %0 = memref.alloc() : memref<3x4xi8> %mask = vector.create_mask %num_elems, %num_elems : vector<2x4xi1> %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru : memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8> return %1 : vector<2x4xi8> } ``` For example, casting to i32 produces: ```bash error: 'vector.bitcast' op failed to verify that all of {source, result} have same rank %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru : ^ ``` Instead of reworking these patterns (that's going to require much more effort), I’ve marked them as 1-D only and extended "TestEmulateNarrowTypePass" with an option to disable the Memref type converter - that's to be able to add negative tests (otherwise, the type converter throws an error we can't really test for). While not ideal, this workaround should suit a test pass.
1 parent ba572ab commit e458434

File tree

3 files changed

+146
-1
lines changed

3 files changed

+146
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
249249
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
250250
ConversionPatternRewriter &rewriter) const override {
251251

252+
// See #115653
253+
if (op.getValueToStore().getType().getRank() != 1)
254+
return rewriter.notifyMatchFailure(op,
255+
"only 1-D vectors are supported ATM");
256+
252257
auto loc = op.getLoc();
253258
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
254259
Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -315,6 +320,11 @@ struct ConvertVectorMaskedStore final
315320
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
316321
ConversionPatternRewriter &rewriter) const override {
317322

323+
// See #115653
324+
if (op.getValueToStore().getType().getRank() != 1)
325+
return rewriter.notifyMatchFailure(op,
326+
"only 1-D vectors are supported ATM");
327+
318328
auto loc = op.getLoc();
319329
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
320330
Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -418,6 +428,11 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
418428
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
419429
ConversionPatternRewriter &rewriter) const override {
420430

431+
// See #115653
432+
if (op.getVectorType().getRank() != 1)
433+
return rewriter.notifyMatchFailure(op,
434+
"only 1-D vectors are supported ATM");
435+
421436
auto loc = op.getLoc();
422437
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
423438
Type oldElementType = op.getType().getElementType();
@@ -517,6 +532,11 @@ struct ConvertVectorMaskedLoad final
517532
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
518533
ConversionPatternRewriter &rewriter) const override {
519534

535+
// See #115653
536+
if (op.getVectorType().getRank() != 1)
537+
return rewriter.notifyMatchFailure(op,
538+
"only 1-D vectors are supported ATM");
539+
520540
auto loc = op.getLoc();
521541
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
522542
Type oldElementType = op.getType().getElementType();
@@ -674,6 +694,11 @@ struct ConvertVectorTransferRead final
674694
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
675695
ConversionPatternRewriter &rewriter) const override {
676696

697+
// See #115653
698+
if (op.getVectorType().getRank() != 1)
699+
return rewriter.notifyMatchFailure(op,
700+
"only 1-D vectors are supported ATM");
701+
677702
auto loc = op.getLoc();
678703
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
679704
Type oldElementType = op.getType().getElementType();
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s
2+
3+
// These tests mimic tests from vector-narrow-type.mlir, but load/store 2-D
4+
// insted of 1-D vectors. That's currently not supported.
5+
6+
///----------------------------------------------------------------------------------------
7+
/// vector.load
8+
///----------------------------------------------------------------------------------------
9+
10+
func.func @vector_load_2d_i8_negative(%arg1: index, %arg2: index) -> vector<2x4xi8> {
11+
%0 = memref.alloc() : memref<3x4xi8>
12+
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<2x4xi8>
13+
return %1 : vector<2x4xi8>
14+
}
15+
16+
// No support for loading 2D vectors - expect no conversions
17+
// CHECK-LABEL: func @vector_load_2d_i8_negative
18+
// CHECK: memref.alloc() : memref<3x4xi8>
19+
// CHECK-NOT: i32
20+
21+
// -----
22+
23+
///----------------------------------------------------------------------------------------
24+
/// vector.transfer_read
25+
///----------------------------------------------------------------------------------------
26+
27+
func.func @vector_transfer_read_2d_i4_negative(%arg1: index, %arg2: index) -> vector<2x8xi4> {
28+
%c0 = arith.constant 0 : i4
29+
%0 = memref.alloc() : memref<3x8xi4>
30+
%1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true, true]} :
31+
memref<3x8xi4>, vector<2x8xi4>
32+
return %1 : vector<2x8xi4>
33+
}
34+
// CHECK-LABEL: func @vector_transfer_read_2d_i4_negative
35+
// CHECK: memref.alloc() : memref<3x8xi4>
36+
// CHECK-NOT: i32
37+
38+
// -----
39+
40+
///----------------------------------------------------------------------------------------
41+
/// vector.maskedload
42+
///----------------------------------------------------------------------------------------
43+
44+
func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
45+
%0 = memref.alloc() : memref<3x4xi8>
46+
%mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1>
47+
%1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
48+
memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
49+
return %1 : vector<2x4xi8>
50+
}
51+
52+
// CHECK-LABEL: func @vector_maskedload_2d_i8_negative
53+
// CHECK: memref.alloc() : memref<3x4xi8>
54+
// CHECK-NOT: i32
55+
56+
// -----
57+
58+
///----------------------------------------------------------------------------------------
59+
/// vector.extract -> vector.masked_load
60+
///----------------------------------------------------------------------------------------
61+
62+
func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x16xi4> {
63+
%0 = memref.alloc() : memref<8x8x16xi4>
64+
%c0 = arith.constant 0 : index
65+
%c16 = arith.constant 16 : index
66+
%c8 = arith.constant 8 : index
67+
%cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
68+
%cst_2 = arith.constant dense<0> : vector<8x16xi4>
69+
%27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
70+
%48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
71+
%50 = vector.maskedload %0[%c0, %c0, %c0], %48, %cst_2 : memref<8x8x16xi4>, vector<8x16xi1>, vector<8x16xi4> into vector<8x16xi4>
72+
%63 = vector.insert %50, %cst_1 [0] : vector<8x16xi4> into vector<8x8x16xi4>
73+
return %63 : vector<8x8x16xi4>
74+
}
75+
76+
// CHECK-LABEL: func @vector_extract_maskedload_2d_i4_negative
77+
// CHECK: memref.alloc() : memref<8x8x16xi4>
78+
// CHECK-NOT: i32
79+
80+
// -----
81+
82+
///----------------------------------------------------------------------------------------
83+
/// vector.store
84+
///----------------------------------------------------------------------------------------
85+
86+
func.func @vector_store_2d_i8_negative(%arg0: vector<2x8xi8>, %arg1: index, %arg2: index) {
87+
%0 = memref.alloc() : memref<4x8xi8>
88+
vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<2x8xi8>
89+
return
90+
}
91+
92+
// CHECK-LABEL: func @vector_store_2d_i8_negative
93+
// CHECK: memref.alloc() : memref<4x8xi8>
94+
// CHECK-NOT: i32
95+
96+
// -----
97+
98+
///----------------------------------------------------------------------------------------
99+
/// vector.maskedstore
100+
///----------------------------------------------------------------------------------------
101+
102+
func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<2x8xi8>) {
103+
%0 = memref.alloc() : memref<3x8xi8>
104+
%mask = vector.create_mask %arg2, %arg2 : vector<2x8xi1>
105+
vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<2x8xi1>, vector<2x8xi8>
106+
return
107+
}
108+
109+
// CHECK-LABEL: func @vector_maskedstore_2d_i8_negative
110+
// CHECK: memref.alloc() : memref<3x8xi8>
111+
// CHECK-NOT: i32

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@ struct TestEmulateNarrowTypePass
7878
IntegerType::get(ty.getContext(), arithComputeBitwidth));
7979
});
8080

81-
memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
81+
// With the type converter enabled, we are effectively unable to write
82+
// negative tests. This is a workaround specifically for negative tests.
83+
if (!disableMemrefTypeConversion)
84+
memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
85+
8286
ConversionTarget target(*ctx);
8387
target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
8488
return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
@@ -109,6 +113,11 @@ struct TestEmulateNarrowTypePass
109113
Option<unsigned> arithComputeBitwidth{
110114
*this, "arith-compute-bitwidth",
111115
llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
116+
117+
Option<bool> disableMemrefTypeConversion{
118+
*this, "skip-memref-type-conversion",
119+
llvm::cl::desc("disable memref type conversion (to test failures)"),
120+
llvm::cl::init(false)};
112121
};
113122
} // namespace
114123

0 commit comments

Comments
 (0)