Skip to content

Commit 145b055

Browse files
[fixup] Address review comments
1 parent af235d9 commit 145b055

File tree

2 files changed

+74
-121
lines changed

2 files changed

+74
-121
lines changed

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,11 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
210210

211211
// -----
212212

213-
// Can flatten the righmost dynamic dimension
213+
// When collapsing memref dimensions, we may include the rightmost dynamic
214+
// dimension (e.g., at position `k`) provided that the strides for dimensions
215+
// `k+1`, `k+2`, etc., ensure contiguity in memory. The stride at position `k`
216+
// itself does not factor into this. (Here "strides" mean both explicit and
217+
// implied by identity map)
214218

215219
func.func @transfer_read_dynamic_dim_to_flatten(
216220
%idx_1: index,
@@ -486,8 +490,8 @@ func.func @transfer_write_leading_dynamic_dims(
486490
// CHECK-128B: memref.collapse_shape
487491

488492
// -----
489-
490-
// The vector could be a non-contiguous slice of the input
493+
494+
// The vector is a non-contiguous slice of the input
491495
// memref.
492496

493497
func.func @negative_transfer_write_dynamic_to_flatten(
@@ -509,7 +513,9 @@ func.func @negative_transfer_write_dynamic_to_flatten(
509513

510514
// -----
511515

512-
func.func @transfer_write_dynamic_to_flatten(
516+
// See the comment in front of @transfer_read_dynamic_dim_to_flatten.
517+
518+
func.func @transfer_write_dynamic_dim_to_flatten(
513519
%idx_1: index,
514520
%idx_2: index,
515521
%vec : vector<1x2x6xi32>,
@@ -524,7 +530,7 @@ func.func @transfer_write_dynamic_to_flatten(
524530

525531
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
526532

527-
// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
533+
// CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
528534
// CHECK-SAME: %[[IDX_1:arg0]]: index
529535
// CHECK-SAME: %[[IDX_2:arg1]]: index
530536
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
@@ -539,7 +545,7 @@ func.func @transfer_write_dynamic_to_flatten(
539545
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
540546
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
541547

542-
// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
548+
// CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
543549
// CHECK-128B-NOT: memref.collapse_shape
544550

545551
// -----

mlir/unittests/IR/MemrefLayoutTest.cpp

Lines changed: 62 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
using namespace mlir;
1616
using namespace mlir::memref;
1717

18+
//
19+
// Test the correctness of `memref::getNumContiguousTrailingDims`
20+
//
1821
TEST(MemRefLayout, numContigDim) {
1922
MLIRContext ctx;
2023
OpBuilder b(&ctx);
@@ -25,79 +28,108 @@ TEST(MemRefLayout, numContigDim) {
2528
return StridedLayoutAttr::get(&ctx, 0, s);
2629
};
2730

28-
// memref<2x2x2xf32, strided<[4,2,1]>
31+
// Create a sequence of test cases, starting with the base case of a
32+
// contiguous 2x2x2 memref with fixed dimensions and then at each step
33+
// introducing one dynamic dimension starting from the right.
34+
// With thus obtained memref, start with maximally contiguous strides
35+
// and then at each step gradually introduce discontinuity by increasing
36+
// a fixed stride size from the left to right.
37+
38+
// In these and the following test cases the intent is to achieve code
39+
// coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
40+
41+
// memref<2x2x2xf32, strided<[4,2,1]>>
2942
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
3043
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
3144

32-
// memref<2x2x2xf32, strided<[8,2,1]>
45+
// memref<2x2x2xf32, strided<[8,2,1]>>
3346
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
3447
EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
3548

36-
// memref<2x2x2xf32, strided<[8,4,1]>
49+
// memref<2x2x2xf32, strided<[8,4,1]>>
3750
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
3851
EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
3952

40-
// memref<2x2x2xf32, strided<[8,4,2]>
53+
// memref<2x2x2xf32, strided<[8,4,2]>>
4154
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
4255
EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
4356

44-
// memref<2x2x?xf32, strided<[?,?,1]>
57+
// memref<2x2x?xf32, strided<[?,?,1]>>
4558
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
4659
EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
4760

48-
// memref<2x2x?xf32, strided<[?,?,2]>
61+
// memref<2x2x?xf32, strided<[?,?,2]>>
4962
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
5063
EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
5164

52-
// memref<2x?x2xf32, strided<[?,2,1]>
65+
// memref<2x?x2xf32, strided<[?,2,1]>>
5366
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
5467
EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
5568

56-
// memref<2x?x2xf32, strided<[?,4,1]>
69+
// memref<2x?x2xf32, strided<[?,4,1]>>
5770
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
5871
EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
5972

60-
// memref<2x?x2xf32, strided<[?,4,2]>
73+
// memref<2x?x2xf32, strided<[?,4,2]>>
6174
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
6275
EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
6376

64-
// memref<?x2x2xf32, strided<[4,2,1]>
77+
// memref<?x2x2xf32, strided<[4,2,1]>>
6578
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
6679
EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
6780

68-
// memref<?x2x2xf32, strided<[8,2,1]>
81+
// memref<?x2x2xf32, strided<[8,2,1]>>
6982
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
7083
EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
7184

72-
// memref<?x2x2xf32, strided<[8,4,1]>
85+
// memref<?x2x2xf32, strided<[8,4,1]>>
7386
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
7487
EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
7588

76-
// memref<?x2x2xf32, strided<[8,4,2]>
89+
// memref<?x2x2xf32, strided<[8,4,2]>>
7790
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
7891
EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
7992

80-
// memref<2x2x1xf32, strided<[2,1,2]>
93+
//
94+
// Repeat a similar process, but this time introduce a unit memref dimension
95+
// to test that strides corresponding to unit dimensions are immaterial, even
96+
// if dynamic.
97+
//
98+
99+
// memref<2x2x1xf32, strided<[2,1,2]>>
81100
auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
82101
EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
83102

84-
// memref<2x2x1xf32, strided<[2,1,?]>
103+
// memref<2x2x1xf32, strided<[2,1,?]>>
85104
auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
86105
EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
87106

88-
// memref<2x2x1xf32, strided<[4,2,2]>
107+
// memref<2x2x1xf32, strided<[4,2,2]>>
89108
auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
90109
EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
91110

92-
// memref<2x1x2xf32, strided<[2,4,1]>
111+
// memref<2x1x2xf32, strided<[2,4,1]>>
93112
auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
94113
EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
95114

96-
// memref<2x1x2xf32, strided<[2,?,1]>
115+
// memref<2x1x2xf32, strided<[2,?,1]>>
97116
auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
98117
EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
118+
119+
//
120+
// Special case for identity maps and no explicit `strided` attribute - the
121+
// memref is entirely contiguous even if the strides cannot be determined
122+
// statically.
123+
//
124+
125+
// memref<?x?x?xf32>
126+
auto m19 = MemRefType::get({_, _, _}, f32);
127+
EXPECT_EQ(m19.getNumContiguousTrailingDims(), 3);
99128
}
100129

130+
//
131+
// Test the member function `memref::areTrailingDimsContiguous`
132+
//
101133
TEST(MemRefLayout, contigTrailingDim) {
102134
MLIRContext ctx;
103135
OpBuilder b(&ctx);
@@ -108,103 +140,18 @@ TEST(MemRefLayout, contigTrailingDim) {
108140
return StridedLayoutAttr::get(&ctx, 0, s);
109141
};
110142

111-
// memref<2x2x2xf32, strided<[4,2,1]>
112-
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
113-
EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
114-
EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
115-
EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
116-
117-
// memref<2x2x2xf32, strided<[8,2,1]>
118-
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
119-
EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
120-
EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
121-
EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
122-
123-
// memref<2x2x2xf32, strided<[8,4,1]>
124-
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
125-
EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
126-
EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
127-
EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
128-
129-
// memref<2x2x2xf32, strided<[8,4,2]>
130-
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
131-
EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
132-
EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
133-
EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
134-
135-
// memref<2x2x?xf32, strided<[?,?,1]>
136-
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
137-
EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
138-
EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
139-
EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
140-
141-
// memref<2x2x?xf32, strided<[?,?,2]>
142-
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
143-
EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
144-
EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
145-
EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
146-
147-
// memref<2x?x2xf32, strided<[?,2,1]>
148-
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
149-
EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
150-
EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
151-
EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
152-
153-
// memref<2x?x2xf32, strided<[?,4,1]>
154-
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
155-
EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
156-
EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
157-
EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
158-
159-
// memref<2x?x2xf32, strided<[?,4,2]>
160-
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
161-
EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
162-
EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
163-
EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
164-
165-
// memref<?x2x2xf32, strided<[4,2,1]>
166-
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
167-
EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
168-
EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
169-
EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
170-
171-
// memref<?x2x2xf32, strided<[8,2,1]>
172-
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
173-
EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
174-
EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
175-
EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
176-
177-
// memref<?x2x2xf32, strided<[8,4,1]>
178-
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
179-
EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
180-
EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
181-
EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
143+
// Pick up a random test case among the ones already present in the file and
144+
// ensure `areTrailingDimsContiguous(k)` returns `true` up to the value
145+
// returned by `getNumContiguousTrailingDims` and `false` from that point on
146+
// up to the memref rank.
182147

183-
// memref<?x2x2xf32, strided<[8,4,2]>
184-
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
185-
EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
186-
EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
187-
EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
188-
}
189-
190-
TEST(MemRefLayout, identityMaps) {
191-
MLIRContext ctx;
192-
OpBuilder b(&ctx);
148+
// memref<2x?x2xf32, strided<[?,2,1]>>
149+
auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
150+
int64_t n = m.getNumContiguousTrailingDims();
151+
for (int64_t i = 0; i <= n; ++i)
152+
EXPECT_TRUE(m.areTrailingDimsContiguous(i));
193153

194-
const int64_t _ = ShapedType::kDynamic;
195-
const FloatType f32 = b.getF32Type();
196-
197-
// memref<2x2x2xf32>
198-
auto m1 = MemRefType::get({2, 2, 2}, f32);
199-
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
200-
EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
201-
EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
202-
EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
203-
204-
// memref<?x?x?xf32>
205-
auto m2 = MemRefType::get({_, _, _}, f32);
206-
EXPECT_EQ(m2.getNumContiguousTrailingDims(), 3);
207-
EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
208-
EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
209-
EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
154+
int64_t r = m.getRank();
155+
for (int64_t i = n + 1; i <= r; ++i)
156+
EXPECT_FALSE(m.areTrailingDimsContiguous(i));
210157
}

0 commit comments

Comments
 (0)