|
15 | 15 | using namespace mlir;
|
16 | 16 | using namespace mlir::memref;
|
17 | 17 |
|
18 |
| -TEST(MemRefLayout, maxCollapseDim) { |
| 18 | +TEST(MemRefLayout, maxContigDim) { |
19 | 19 | MLIRContext ctx;
|
20 | 20 | OpBuilder b(&ctx);
|
21 | 21 |
|
@@ -76,6 +76,26 @@ TEST(MemRefLayout, maxCollapseDim) {
|
76 | 76 | // memref<?x2x2xf32, strided<[8,4,2]>
|
77 | 77 | auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
|
78 | 78 | EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
|
| 79 | + |
| 80 | + // memref<2x2x1xf32, strided<[2,1,2]> |
| 81 | + auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2})); |
| 82 | + EXPECT_EQ(m14.getMaxContiguousTrailingDims(), 3); |
| 83 | + |
| 84 | + // memref<2x2x1xf32, strided<[2,1,?]> |
| 85 | + auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _})); |
| 86 | + EXPECT_EQ(m15.getMaxContiguousTrailingDims(), 3); |
| 87 | + |
| 88 | + // memref<2x2x1xf32, strided<[4,2,2]> |
| 89 | + auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2})); |
| 90 | + EXPECT_EQ(m16.getMaxContiguousTrailingDims(), 1); |
| 91 | + |
| 92 | + // memref<2x1x2xf32, strided<[2,4,1]> |
| 93 | + auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1})); |
| 94 | + EXPECT_EQ(m17.getMaxContiguousTrailingDims(), 3); |
| 95 | + |
| 96 | + // memref<2x1x2xf32, strided<[2,?,1]> |
| 97 | + auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1})); |
| 98 | + EXPECT_EQ(m18.getMaxContiguousTrailingDims(), 3); |
79 | 99 | }
|
80 | 100 |
|
81 | 101 | TEST(MemRefLayout, contigTrailingDim) {
|
|
0 commit comments