Skip to content

Commit e3b6e21

Browse files
[fixup] Handle unit dimensions by ignoring the corresponding stride
1 parent fe4444f commit e3b6e21

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,12 +669,14 @@ int64_t MemRefType::getMaxContiguousTrailingDims() {
669669
// `s0, s1, ..., sn-1` is contiguous up to dimension `k`
670670
// if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
671671
// for `i` in `[k, n-1]`.
672+
// Ignore stride elements if the corresponding dimension is 1, as they are
673+
// of no consequence.
672674
int64_t dimProduct = 1;
673675
for (int64_t i = n - 1; i >= 0; --i) {
674-
if (strides[i] != dimProduct)
675-
return n - i - 1;
676676
if (shape[i] == 1)
677677
continue;
678+
if (strides[i] != dimProduct)
679+
return n - i - 1;
678680
if (shape[i] == ShapedType::kDynamic)
679681
return n - i;
680682
dimProduct *= shape[i];

mlir/unittests/Dialect/MemRef/LayoutTest.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
using namespace mlir;
1616
using namespace mlir::memref;
1717

18-
TEST(MemRefLayout, maxCollapseDim) {
18+
TEST(MemRefLayout, maxContigDim) {
1919
MLIRContext ctx;
2020
OpBuilder b(&ctx);
2121

@@ -76,6 +76,26 @@ TEST(MemRefLayout, maxCollapseDim) {
7676
// memref<?x2x2xf32, strided<[8,4,2]>
7777
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
7878
EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
79+
80+
// memref<2x2x1xf32, strided<[2,1,2]>
81+
auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
82+
EXPECT_EQ(m14.getMaxContiguousTrailingDims(), 3);
83+
84+
// memref<2x2x1xf32, strided<[2,1,?]>
85+
auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
86+
EXPECT_EQ(m15.getMaxContiguousTrailingDims(), 3);
87+
88+
// memref<2x2x1xf32, strided<[4,2,2]>
89+
auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
90+
EXPECT_EQ(m16.getMaxContiguousTrailingDims(), 1);
91+
92+
// memref<2x1x2xf32, strided<[2,4,1]>
93+
auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
94+
EXPECT_EQ(m17.getMaxContiguousTrailingDims(), 3);
95+
96+
// memref<2x1x2xf32, strided<[2,?,1]>
97+
auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
98+
EXPECT_EQ(m18.getMaxContiguousTrailingDims(), 3);
7999
}
80100

81101
TEST(MemRefLayout, contigTrailingDim) {

0 commit comments

Comments
 (0)