Skip to content

Commit fe4444f

Browse files
[fixup] Misc NFC changes
- rename `getMaxCollapsabelTrailingDims` to `getMaxContiguousTrailingDims` - new set of examples - remove redundant call to `isIdentify()` - make sure a variable type is visible on the declaration line - some micro-optimisation
1 parent 572d176 commit fe4444f

File tree

3 files changed

+36
-29
lines changed

3 files changed

+36
-29
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -838,19 +838,25 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
838838
///
839839
bool areTrailingDimsContiguous(int64_t n);
840840

841-
/// Return the maximum number of trailing dimensions that can be
842-
/// collapsed.
841+
/// Return the maximum number of trailing dimensions that are
842+
/// contiguous.
843843
///
844844
/// Examples:
845-
/// - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
846-
/// trailing dimensions is 0
847-
/// - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
845+
/// - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsable
848846
/// trailing dimensions is 3
849-
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
850-
/// collapsable trailing dimensions is 2.
851-
/// - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
852-
/// is 4.
853-
int64_t getMaxCollapsableTrailingDims();
847+
/// - memref<5x3x2xi8, strided<[12,2,1]>>, the number of collapsable
848+
/// trailing dimensions is 2 (dimension 0 is non-contiguous)
849+
/// - memref<5x3x2xi8, strided<[12,4,1]>>, the number of collapsable
850+
/// trailing dimensions is 1 (dimension 1 is non-contiguous)
851+
/// - memref<5x3x2xi8, strided<[12,4,2]>>, the number of collapsable
852+
/// trailing dimensions is 0 (dimension 2 is non-contiguous)
853+
/// - memref<?x3x2xi8, strided<[6,2,1]>>, the number of collapsable
854+
/// trailing dimensions is 3
855+
/// - memref<?x3x2xi8, strided<[12,2,1]>>, the number of collapsable
856+
/// trailing dimensions is 2 (dimension 0 is non-contiguous)
857+
/// - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsable
858+
/// trailing dimensions is 2 (stride 0 is dynamic)
859+
int64_t getMaxContiguousTrailingDims();
854860

855861
/// Return a version of this type with identity layout if it can be
856862
/// determined statically that the layout is the canonical contiguous

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,11 +646,10 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
646646
}
647647

648648
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
649-
return getLayout().isIdentity() ||
650-
getMaxCollapsableTrailingDims() >= std::min(n, getRank());
649+
return getMaxContiguousTrailingDims() >= std::min(n, getRank());
651650
}
652651

653-
int64_t MemRefType::getMaxCollapsableTrailingDims() {
652+
int64_t MemRefType::getMaxContiguousTrailingDims() {
654653
const int64_t n = getRank();
655654

656655
// memrefs with identity layout are entirely contiguous.
@@ -664,7 +663,7 @@ int64_t MemRefType::getMaxCollapsableTrailingDims() {
664663
if (!succeeded(getStridesAndOffset(strides, offset)))
665664
return 0;
666665

667-
auto shape = getShape();
666+
ArrayRef<int64_t> shape = getShape();
668667

669668
// A memref with dimensions `d0, d1, ..., dn-1` and strides
670669
// `s0, s1, ..., sn-1` is contiguous up to dimension `k`
@@ -674,6 +673,8 @@ int64_t MemRefType::getMaxCollapsableTrailingDims() {
674673
for (int64_t i = n - 1; i >= 0; --i) {
675674
if (strides[i] != dimProduct)
676675
return n - i - 1;
676+
if (shape[i] == 1)
677+
continue;
677678
if (shape[i] == ShapedType::kDynamic)
678679
return n - i;
679680
dimProduct *= shape[i];

mlir/unittests/Dialect/MemRef/LayoutTest.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,55 +27,55 @@ TEST(MemRefLayout, maxCollapseDim) {
2727

2828
// memref<2x2x2xf32, strided<[4,2,1]>
2929
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
30-
EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
30+
EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
3131

3232
// memref<2x2x2xf32, strided<[8,2,1]>
3333
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
34-
EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
34+
EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 2);
3535

3636
// memref<2x2x2xf32, strided<[8,4,1]>
3737
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
38-
EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
38+
EXPECT_EQ(m3.getMaxContiguousTrailingDims(), 1);
3939

4040
// memref<2x2x2xf32, strided<[8,4,2]>
4141
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
42-
EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
42+
EXPECT_EQ(m4.getMaxContiguousTrailingDims(), 0);
4343

4444
// memref<2x2x?xf32, strided<[?,?,1]>
4545
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
46-
EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
46+
EXPECT_EQ(m5.getMaxContiguousTrailingDims(), 1);
4747

4848
// memref<2x2x?xf32, strided<[?,?,2]>
4949
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
50-
EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
50+
EXPECT_EQ(m6.getMaxContiguousTrailingDims(), 0);
5151

5252
// memref<2x?x2xf32, strided<[?,2,1]>
5353
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
54-
EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
54+
EXPECT_EQ(m7.getMaxContiguousTrailingDims(), 2);
5555

5656
// memref<2x?x2xf32, strided<[?,4,1]>
5757
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
58-
EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
58+
EXPECT_EQ(m8.getMaxContiguousTrailingDims(), 1);
5959

6060
// memref<2x?x2xf32, strided<[?,4,2]>
6161
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
62-
EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
62+
EXPECT_EQ(m9.getMaxContiguousTrailingDims(), 0);
6363

6464
// memref<?x2x2xf32, strided<[4,2,1]>
6565
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
66-
EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
66+
EXPECT_EQ(m10.getMaxContiguousTrailingDims(), 3);
6767

6868
// memref<?x2x2xf32, strided<[8,2,1]>
6969
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
70-
EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
70+
EXPECT_EQ(m11.getMaxContiguousTrailingDims(), 2);
7171

7272
// memref<?x2x2xf32, strided<[8,4,1]>
7373
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
74-
EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
74+
EXPECT_EQ(m12.getMaxContiguousTrailingDims(), 1);
7575

7676
// memref<?x2x2xf32, strided<[8,4,2]>
7777
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
78-
EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
78+
EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
7979
}
8080

8181
TEST(MemRefLayout, contigTrailingDim) {
@@ -176,14 +176,14 @@ TEST(MemRefLayout, identityMaps) {
176176

177177
// memref<2x2x2xf32>
178178
auto m1 = MemRefType::get({2, 2, 2}, f32);
179-
EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
179+
EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
180180
EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
181181
EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
182182
EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
183183

184184
// memref<?x?x?xf32>
185185
auto m2 = MemRefType::get({_, _, _}, f32);
186-
EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
186+
EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 3);
187187
EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
188188
EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
189189
EXPECT_TRUE(m2.areTrailingDimsContiguous(3));

0 commit comments

Comments
 (0)