15
15
using namespace mlir ;
16
16
using namespace mlir ::memref;
17
17
18
+ //
19
+ // Test the correctness of `memref::getNumContiguousTrailingDims`
20
+ //
18
21
TEST (MemRefLayout, numContigDim) {
19
22
MLIRContext ctx;
20
23
OpBuilder b (&ctx);
@@ -25,79 +28,108 @@ TEST(MemRefLayout, numContigDim) {
25
28
return StridedLayoutAttr::get (&ctx, 0 , s);
26
29
};
27
30
28
- // memref<2x2x2xf32, strided<[4,2,1]>
31
+ // Create a sequence of test cases, starting with the base case of a
32
+ // contiguous 2x2x2 memref with fixed dimensions and then at each step
33
+ // introducing one dynamic dimension starting from the right.
34
+ // With thus obtained memref, start with maximally contiguous strides
35
+ // and then at each step gradually introduce discontinuity by increasing
36
+ // a fixed stride size from the left to right.
37
+
38
+ // In these and the following test cases the intent is to achieve code
39
+ // coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
40
+
41
+ // memref<2x2x2xf32, strided<[4,2,1]>>
29
42
auto m1 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
30
43
EXPECT_EQ (m1.getNumContiguousTrailingDims (), 3 );
31
44
32
- // memref<2x2x2xf32, strided<[8,2,1]>
45
+ // memref<2x2x2xf32, strided<[8,2,1]>>
33
46
auto m2 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
34
47
EXPECT_EQ (m2.getNumContiguousTrailingDims (), 2 );
35
48
36
- // memref<2x2x2xf32, strided<[8,4,1]>
49
+ // memref<2x2x2xf32, strided<[8,4,1]>>
37
50
auto m3 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
38
51
EXPECT_EQ (m3.getNumContiguousTrailingDims (), 1 );
39
52
40
- // memref<2x2x2xf32, strided<[8,4,2]>
53
+ // memref<2x2x2xf32, strided<[8,4,2]>>
41
54
auto m4 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
42
55
EXPECT_EQ (m4.getNumContiguousTrailingDims (), 0 );
43
56
44
- // memref<2x2x?xf32, strided<[?,?,1]>
57
+ // memref<2x2x?xf32, strided<[?,?,1]>>
45
58
auto m5 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 1 }));
46
59
EXPECT_EQ (m5.getNumContiguousTrailingDims (), 1 );
47
60
48
- // memref<2x2x?xf32, strided<[?,?,2]>
61
+ // memref<2x2x?xf32, strided<[?,?,2]>>
49
62
auto m6 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 2 }));
50
63
EXPECT_EQ (m6.getNumContiguousTrailingDims (), 0 );
51
64
52
- // memref<2x?x2xf32, strided<[?,2,1]>
65
+ // memref<2x?x2xf32, strided<[?,2,1]>>
53
66
auto m7 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 2 , 1 }));
54
67
EXPECT_EQ (m7.getNumContiguousTrailingDims (), 2 );
55
68
56
- // memref<2x?x2xf32, strided<[?,4,1]>
69
+ // memref<2x?x2xf32, strided<[?,4,1]>>
57
70
auto m8 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 1 }));
58
71
EXPECT_EQ (m8.getNumContiguousTrailingDims (), 1 );
59
72
60
- // memref<2x?x2xf32, strided<[?,4,2]>
73
+ // memref<2x?x2xf32, strided<[?,4,2]>>
61
74
auto m9 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 2 }));
62
75
EXPECT_EQ (m9.getNumContiguousTrailingDims (), 0 );
63
76
64
- // memref<?x2x2xf32, strided<[4,2,1]>
77
+ // memref<?x2x2xf32, strided<[4,2,1]>>
65
78
auto m10 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
66
79
EXPECT_EQ (m10.getNumContiguousTrailingDims (), 3 );
67
80
68
- // memref<?x2x2xf32, strided<[8,2,1]>
81
+ // memref<?x2x2xf32, strided<[8,2,1]>>
69
82
auto m11 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
70
83
EXPECT_EQ (m11.getNumContiguousTrailingDims (), 2 );
71
84
72
- // memref<?x2x2xf32, strided<[8,4,1]>
85
+ // memref<?x2x2xf32, strided<[8,4,1]>>
73
86
auto m12 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
74
87
EXPECT_EQ (m12.getNumContiguousTrailingDims (), 1 );
75
88
76
- // memref<?x2x2xf32, strided<[8,4,2]>
89
+ // memref<?x2x2xf32, strided<[8,4,2]>>
77
90
auto m13 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
78
91
EXPECT_EQ (m13.getNumContiguousTrailingDims (), 0 );
79
92
80
- // memref<2x2x1xf32, strided<[2,1,2]>
93
+ //
94
+ // Repeat a similar process, but this time introduce a unit memref dimension
95
+ // to test that strides corresponding to unit dimensions are immaterial, even
96
+ // if dynamic.
97
+ //
98
+
99
+ // memref<2x2x1xf32, strided<[2,1,2]>>
81
100
auto m14 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({2 , 1 , 2 }));
82
101
EXPECT_EQ (m14.getNumContiguousTrailingDims (), 3 );
83
102
84
- // memref<2x2x1xf32, strided<[2,1,?]>
103
+ // memref<2x2x1xf32, strided<[2,1,?]>>
85
104
auto m15 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({2 , 1 , _}));
86
105
EXPECT_EQ (m15.getNumContiguousTrailingDims (), 3 );
87
106
88
- // memref<2x2x1xf32, strided<[4,2,2]>
107
+ // memref<2x2x1xf32, strided<[4,2,2]>>
89
108
auto m16 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({4 , 2 , 2 }));
90
109
EXPECT_EQ (m16.getNumContiguousTrailingDims (), 1 );
91
110
92
- // memref<2x1x2xf32, strided<[2,4,1]>
111
+ // memref<2x1x2xf32, strided<[2,4,1]>>
93
112
auto m17 = MemRefType::get ({2 , 1 , 2 }, f32 , strided ({2 , 4 , 1 }));
94
113
EXPECT_EQ (m17.getNumContiguousTrailingDims (), 3 );
95
114
96
- // memref<2x1x2xf32, strided<[2,?,1]>
115
+ // memref<2x1x2xf32, strided<[2,?,1]>>
97
116
auto m18 = MemRefType::get ({2 , 1 , 2 }, f32 , strided ({2 , _, 1 }));
98
117
EXPECT_EQ (m18.getNumContiguousTrailingDims (), 3 );
118
+
119
+ //
120
+ // Special case for identity maps and no explicit `strided` attribute - the
121
+ // memref is entirely contiguous even if the strides cannot be determined
122
+ // statically.
123
+ //
124
+
125
+ // memref<?x?x?xf32>
126
+ auto m19 = MemRefType::get ({_, _, _}, f32 );
127
+ EXPECT_EQ (m19.getNumContiguousTrailingDims (), 3 );
99
128
}
100
129
130
+ //
131
+ // Test the member function `memref::areTrailingDimsContiguous`
132
+ //
101
133
TEST (MemRefLayout, contigTrailingDim) {
102
134
MLIRContext ctx;
103
135
OpBuilder b (&ctx);
@@ -108,103 +140,18 @@ TEST(MemRefLayout, contigTrailingDim) {
108
140
return StridedLayoutAttr::get (&ctx, 0 , s);
109
141
};
110
142
111
- // memref<2x2x2xf32, strided<[4,2,1]>
112
- auto m1 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
113
- EXPECT_TRUE (m1.areTrailingDimsContiguous (1 ));
114
- EXPECT_TRUE (m1.areTrailingDimsContiguous (2 ));
115
- EXPECT_TRUE (m1.areTrailingDimsContiguous (3 ));
116
-
117
- // memref<2x2x2xf32, strided<[8,2,1]>
118
- auto m2 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
119
- EXPECT_TRUE (m2.areTrailingDimsContiguous (1 ));
120
- EXPECT_TRUE (m2.areTrailingDimsContiguous (2 ));
121
- EXPECT_FALSE (m2.areTrailingDimsContiguous (3 ));
122
-
123
- // memref<2x2x2xf32, strided<[8,4,1]>
124
- auto m3 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
125
- EXPECT_TRUE (m3.areTrailingDimsContiguous (1 ));
126
- EXPECT_FALSE (m3.areTrailingDimsContiguous (2 ));
127
- EXPECT_FALSE (m3.areTrailingDimsContiguous (3 ));
128
-
129
- // memref<2x2x2xf32, strided<[8,4,2]>
130
- auto m4 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
131
- EXPECT_FALSE (m4.areTrailingDimsContiguous (1 ));
132
- EXPECT_FALSE (m4.areTrailingDimsContiguous (2 ));
133
- EXPECT_FALSE (m4.areTrailingDimsContiguous (3 ));
134
-
135
- // memref<2x2x?xf32, strided<[?,?,1]>
136
- auto m5 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 1 }));
137
- EXPECT_TRUE (m5.areTrailingDimsContiguous (1 ));
138
- EXPECT_FALSE (m5.areTrailingDimsContiguous (2 ));
139
- EXPECT_FALSE (m5.areTrailingDimsContiguous (3 ));
140
-
141
- // memref<2x2x?xf32, strided<[?,?,2]>
142
- auto m6 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 2 }));
143
- EXPECT_FALSE (m6.areTrailingDimsContiguous (1 ));
144
- EXPECT_FALSE (m6.areTrailingDimsContiguous (2 ));
145
- EXPECT_FALSE (m6.areTrailingDimsContiguous (3 ));
146
-
147
- // memref<2x?x2xf32, strided<[?,2,1]>
148
- auto m7 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 2 , 1 }));
149
- EXPECT_TRUE (m7.areTrailingDimsContiguous (1 ));
150
- EXPECT_TRUE (m7.areTrailingDimsContiguous (2 ));
151
- EXPECT_FALSE (m7.areTrailingDimsContiguous (3 ));
152
-
153
- // memref<2x?x2xf32, strided<[?,4,1]>
154
- auto m8 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 1 }));
155
- EXPECT_TRUE (m8.areTrailingDimsContiguous (1 ));
156
- EXPECT_FALSE (m8.areTrailingDimsContiguous (2 ));
157
- EXPECT_FALSE (m8.areTrailingDimsContiguous (3 ));
158
-
159
- // memref<2x?x2xf32, strided<[?,4,2]>
160
- auto m9 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 2 }));
161
- EXPECT_FALSE (m9.areTrailingDimsContiguous (1 ));
162
- EXPECT_FALSE (m9.areTrailingDimsContiguous (2 ));
163
- EXPECT_FALSE (m9.areTrailingDimsContiguous (3 ));
164
-
165
- // memref<?x2x2xf32, strided<[4,2,1]>
166
- auto m10 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
167
- EXPECT_TRUE (m10.areTrailingDimsContiguous (1 ));
168
- EXPECT_TRUE (m10.areTrailingDimsContiguous (2 ));
169
- EXPECT_TRUE (m10.areTrailingDimsContiguous (3 ));
170
-
171
- // memref<?x2x2xf32, strided<[8,2,1]>
172
- auto m11 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
173
- EXPECT_TRUE (m11.areTrailingDimsContiguous (1 ));
174
- EXPECT_TRUE (m11.areTrailingDimsContiguous (2 ));
175
- EXPECT_FALSE (m11.areTrailingDimsContiguous (3 ));
176
-
177
- // memref<?x2x2xf32, strided<[8,4,1]>
178
- auto m12 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
179
- EXPECT_TRUE (m12.areTrailingDimsContiguous (1 ));
180
- EXPECT_FALSE (m12.areTrailingDimsContiguous (2 ));
181
- EXPECT_FALSE (m12.areTrailingDimsContiguous (3 ));
143
+ // Pick up a random test case among the ones already present in the file and
144
+ // ensure `areTrailingDimsContiguous(k)` returns `true` up to the value
145
+ // returned by `getNumContiguousTrailingDims` and `false` from that point on
146
+ // up to the memref rank.
182
147
183
- // memref<?x2x2xf32, strided<[8,4,2]>
184
- auto m13 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
185
- EXPECT_FALSE (m13.areTrailingDimsContiguous (1 ));
186
- EXPECT_FALSE (m13.areTrailingDimsContiguous (2 ));
187
- EXPECT_FALSE (m13.areTrailingDimsContiguous (3 ));
188
- }
189
-
190
- TEST (MemRefLayout, identityMaps) {
191
- MLIRContext ctx;
192
- OpBuilder b (&ctx);
148
+ // memref<2x?x2xf32, strided<[?,2,1]>>
149
+ auto m = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 2 , 1 }));
150
+ int64_t n = m.getNumContiguousTrailingDims ();
151
+ for (int64_t i = 0 ; i <= n; ++i)
152
+ EXPECT_TRUE (m.areTrailingDimsContiguous (i));
193
153
194
- const int64_t _ = ShapedType::kDynamic ;
195
- const FloatType f32 = b.getF32Type ();
196
-
197
- // memref<2x2x2xf32>
198
- auto m1 = MemRefType::get ({2 , 2 , 2 }, f32 );
199
- EXPECT_EQ (m1.getNumContiguousTrailingDims (), 3 );
200
- EXPECT_TRUE (m1.areTrailingDimsContiguous (1 ));
201
- EXPECT_TRUE (m1.areTrailingDimsContiguous (2 ));
202
- EXPECT_TRUE (m1.areTrailingDimsContiguous (3 ));
203
-
204
- // memref<?x?x?xf32>
205
- auto m2 = MemRefType::get ({_, _, _}, f32 );
206
- EXPECT_EQ (m2.getNumContiguousTrailingDims (), 3 );
207
- EXPECT_TRUE (m2.areTrailingDimsContiguous (1 ));
208
- EXPECT_TRUE (m2.areTrailingDimsContiguous (2 ));
209
- EXPECT_TRUE (m2.areTrailingDimsContiguous (3 ));
154
+ int64_t r = m.getRank ();
155
+ for (int64_t i = n + 1 ; i <= r; ++i)
156
+ EXPECT_FALSE (m.areTrailingDimsContiguous (i));
210
157
}
0 commit comments