1
1
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
2
+ // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
2
3
3
4
func.func @transfer_read_dims_match_contiguous (
4
5
%arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <5 x4 x3 x2 xi8 > {
@@ -16,6 +17,9 @@ func.func @transfer_read_dims_match_contiguous(
16
17
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
17
18
// CHECK: return %[[VEC2D]]
18
19
20
+ // CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
21
+ // CHECK-128B: memref.collapse_shape
22
+
19
23
// -----
20
24
21
25
func.func @transfer_read_dims_match_contiguous_empty_stride (
@@ -27,13 +31,16 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
27
31
return %v : vector <5 x4 x3 x2 xi8 >
28
32
}
29
33
30
- // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
34
+ // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
31
35
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
32
36
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
33
37
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
34
38
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
35
39
// CHECK: return %[[VEC2D]]
36
40
41
+ // CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
42
+ // CHECK-128B: memref.collapse_shape
43
+
37
44
// -----
38
45
39
46
// The shape of the memref and the vector don't match, but the vector is a
@@ -57,6 +64,9 @@ func.func @transfer_read_dims_mismatch_contiguous(
57
64
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
58
65
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
59
66
67
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
68
+ // CHECK-128B: memref.collapse_shape
69
+
60
70
// -----
61
71
62
72
func.func @transfer_read_dims_mismatch_non_zero_indices (
@@ -66,7 +76,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
66
76
%m_out: memref <1 x2 x6 xi32 >) {
67
77
%c0 = arith.constant 0 : index
68
78
%c0_i32 = arith.constant 0 : i32
69
- %2 = vector.transfer_read %m_in [%c0 , %idx_1 , %idx_2 , %c0 ], %c0_i32 {in_bounds = [true , true , true ]} :
79
+ %2 = vector.transfer_read %m_in [%c0 , %idx_1 , %idx_2 , %c0 ], %c0_i32 {in_bounds = [true , true , true ]} :
70
80
memref <1 x43 x4 x6 xi32 >, vector <1 x2 x6 xi32 >
71
81
vector.transfer_write %2 , %m_out [%c0 , %c0 , %c0 ] {in_bounds = [true , true , true ]} :
72
82
vector <1 x2 x6 xi32 >, memref <1 x2 x6 xi32 >
@@ -87,6 +97,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
87
97
// CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
88
98
// CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
89
99
100
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
101
+ // CHECK-128B-NOT: memref.collapse_shape
102
+
90
103
// -----
91
104
92
105
// The input memref has a dynamic trailing shape and hence is not flattened.
@@ -99,7 +112,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
99
112
%m_out: memref <1 x2 x6 xi32 >) {
100
113
%c0 = arith.constant 0 : index
101
114
%c0_i32 = arith.constant 0 : i32
102
- %2 = vector.transfer_read %m_in [%c0 , %idx_1 , %idx_2 , %c0 ], %c0_i32 {in_bounds = [true , true , true ]} :
115
+ %2 = vector.transfer_read %m_in [%c0 , %idx_1 , %idx_2 , %c0 ], %c0_i32 {in_bounds = [true , true , true ]} :
103
116
memref <1 x?x4 x6 xi32 >, vector <1 x2 x6 xi32 >
104
117
vector.transfer_write %2 , %m_out [%c0 , %c0 , %c0 ] {in_bounds = [true , true , true ]} :
105
118
vector <1 x2 x6 xi32 >, memref <1 x2 x6 xi32 >
@@ -115,6 +128,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
115
128
// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
116
129
// CHECK: vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
117
130
131
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
132
+ // CHECK-128B-NOT: memref.collapse_shape
133
+
118
134
// -----
119
135
120
136
func.func @transfer_read_dims_mismatch_non_contiguous (
@@ -130,6 +146,9 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
130
146
// CHECK-NOT: memref.collapse_shape
131
147
// CHECK-NOT: vector.shape_cast
132
148
149
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
150
+ // CHECK-128B-NOT: memref.collapse_shape
151
+
133
152
// -----
134
153
135
154
func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride (
@@ -141,10 +160,13 @@ func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
141
160
return %v : vector <2 x1 x2 x2 xi8 >
142
161
}
143
162
144
- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
163
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
145
164
// CHECK-NOT: memref.collapse_shape
146
165
// CHECK-NOT: vector.shape_cast
147
166
167
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
168
+ // CHECK-128B-NOT: memref.collapse_shape
169
+
148
170
// -----
149
171
150
172
func.func @transfer_write_dims_match_contiguous (
@@ -155,13 +177,16 @@ func.func @transfer_write_dims_match_contiguous(
155
177
return
156
178
}
157
179
158
- // CHECK-LABEL: func @transfer_write_dims_match_contiguous
180
+ // CHECK-LABEL: func @transfer_write_dims_match_contiguous(
159
181
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
160
182
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
161
183
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
162
184
// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
163
185
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
164
186
187
+ // CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous(
188
+ // CHECK-128B: memref.collapse_shape
189
+
165
190
// -----
166
191
167
192
func.func @transfer_write_dims_mismatch_contiguous (
@@ -182,6 +207,9 @@ func.func @transfer_write_dims_mismatch_contiguous(
182
207
// CHECK: return
183
208
// CHECK: }
184
209
210
+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
211
+ // CHECK-128B: memref.collapse_shape
212
+
185
213
// -----
186
214
187
215
func.func @transfer_write_dims_mismatch_non_contiguous (
@@ -196,6 +224,9 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
196
224
// CHECK-NOT: memref.collapse_shape
197
225
// CHECK-NOT: vector.shape_cast
198
226
227
+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
228
+ // CHECK-128B-NOT: memref.collapse_shape
229
+
199
230
// -----
200
231
201
232
func.func @transfer_write_0d (%arg : memref <i8 >, %vec : vector <i8 >) {
@@ -207,6 +238,10 @@ func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
207
238
// CHECK-NOT: memref.collapse_shape
208
239
// CHECK-NOT: vector.shape_cast
209
240
241
+ // CHECK-128B-LABEL: func @transfer_write_0d(
242
+ // CHECK-128B-NOT: memref.collapse_shape
243
+ // CHECK-128B-NOT: vector.shape_cast
244
+
210
245
// -----
211
246
212
247
func.func @transfer_read_0d (%arg : memref <i8 >) -> vector <i8 > {
@@ -219,6 +254,10 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
219
254
// CHECK-NOT: memref.collapse_shape
220
255
// CHECK-NOT: vector.shape_cast
221
256
257
+ // CHECK-128B-LABEL: func @transfer_read_0d(
258
+ // CHECK-128B-NOT: memref.collapse_shape
259
+ // CHECK-128B-NOT: vector.shape_cast
260
+
222
261
// -----
223
262
224
263
func.func @transfer_read_flattenable_with_dynamic_dims_and_indices (%arg0 : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) -> vector <8 x4 xi8 > {
@@ -241,6 +280,9 @@ func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memre
241
280
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
242
281
// CHECK: return %[[VEC2D]] : vector<8x4xi8>
243
282
283
+ // CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
284
+ // CHECK-128B: memref.collapse_shape
285
+
244
286
// -----
245
287
246
288
func.func @transfer_write_flattenable_with_dynamic_dims_and_indices (%vec : vector <8 x4 xi8 >, %dst : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) {
@@ -260,6 +302,9 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
260
302
// CHECK-SAME: {in_bounds = [true]}
261
303
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
262
304
305
+ // CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
306
+ // CHECK-128B: memref.collapse_shape
307
+
263
308
// -----
264
309
265
310
func.func @transfer_read_flattenable_negative (
@@ -274,6 +319,9 @@ func.func @transfer_read_flattenable_negative(
274
319
// CHECK-LABEL: func @transfer_read_flattenable_negative
275
320
// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>
276
321
322
+ // CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
323
+ // CHECK-128B-NOT: memref.collapse_shape
324
+
277
325
// -----
278
326
279
327
func.func @transfer_read_flattenable_negative2 (
@@ -288,6 +336,9 @@ func.func @transfer_read_flattenable_negative2(
288
336
// CHECK-LABEL: func @transfer_read_flattenable_negative2
289
337
// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
290
338
339
+ // CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
340
+ // CHECK-128B-NOT: memref.collapse_shape
341
+
291
342
// -----
292
343
293
344
func.func @fold_unit_dim_add_basic (%arg0 : vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
@@ -302,6 +353,9 @@ func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
302
353
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
303
354
// CHECK: return %[[VAL_4]] : vector<1x8xi32>
304
355
356
+ // CHECK-128B-LABEL: func @fold_unit_dim_add_basic(
357
+ // CHECK-128B-NOT: memref.collapse_shape
358
+
305
359
// -----
306
360
307
361
func.func @fold_unit_dim_add_leading_and_trailing (%arg0 : vector <1 x8 x1 xi32 >) -> vector <1 x8 x1 xi32 > {
@@ -316,6 +370,9 @@ func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) ->
316
370
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32>
317
371
// CHECK: return %[[VAL_4]] : vector<1x8x1xi32>
318
372
373
+ // CHECK-128B-LABEL: func @fold_unit_dim_add_leading_and_trailing(
374
+ // CHECK-128B-NOT: memref.collapse_shape
375
+
319
376
// -----
320
377
321
378
func.func @fold_unit_dim_add (%arg0 : vector <8 x1 xi32 >,
@@ -334,6 +391,9 @@ func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
334
391
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
335
392
// CHECK: return %[[VAL_4]] : vector<8xi32>
336
393
394
+ // CHECK-128B-LABEL: func @fold_unit_dim_add(
395
+ // CHECK-128B-NOT: memref.collapse_shape
396
+
337
397
// -----
338
398
339
399
func.func @fold_unit_dim_mulf (%arg0 : vector <8 x[2 ]x1 xf32 >,
@@ -352,6 +412,9 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
352
412
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
353
413
// CHECK: return %[[VAL_4]] : vector<8x[2]xf32>
354
414
415
+ // CHECK-128B-LABEL: func @fold_unit_dim_mulf(
416
+ // CHECK-128B-NOT: memref.collapse_shape
417
+
355
418
// -----
356
419
357
420
func.func @fold_unit_dim_sitofp (%arg0 : vector <8 x[2 ]x1 xi8 >) -> vector <8 x[2 ]xf32 > {
@@ -367,6 +430,9 @@ func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32>
367
430
// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
368
431
// CHECK: return %[[VAL_2]] : vector<8x[2]xf32>
369
432
433
+ // CHECK-128B-LABEL: func @fold_unit_dim_sitofp(
434
+ // CHECK-128B-NOT: memref.collapse_shape
435
+
370
436
// -----
371
437
372
438
// All shape casts are folded away
@@ -389,3 +455,7 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
389
455
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
390
456
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
391
457
// CHECK: return %[[VAL_4]] : vector<8xi32>
458
+
459
+ // CHECK-128B-LABEL: func @fold_unit_dims_entirely(
460
+ // CHECK-128B-NOT: memref.collapse_shape
461
+
0 commit comments