@@ -70,41 +70,10 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
// -----

- // The shape of the memref and the vector don't match, but the vector is a
- // contiguous subset of the memref, so "flattenable". The leading unit dimensions
- // of the vector have no effect on the memref area read even if they
- // span a non-contiguous part of the memref.
-
- func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
-     %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-
-   %c0 = arith.constant 0 : index
-   %cst = arith.constant 0 : i8
-   %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
-     memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
-   return %res : vector<1x1x2x2xi8>
- }
-
- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
- // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
- // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
- // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
- // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
- // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
- // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
- // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
-
- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
- // CHECK-128B: memref.collapse_shape
-
- // -----
-
// The shape of the memref and the vector don't match, but the vector is a
// contiguous subset of the memref, so "flattenable".

- func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+ func.func @transfer_read_dims_mismatch_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {

  %c0 = arith.constant 0 : index
@@ -114,7 +83,7 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
  return %res : vector<2x3x2xi8>
}

- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -126,9 +95,73 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
// CHECK: return %[[VEC]] : vector<2x3x2xi8>

- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+ // CHECK-128B: memref.collapse_shape
+
+ // -----
+
+ // The shape of the memref and the vector don't match, but the mismatch is only
+ // at the leading unit dimensions of the vector.
+
+ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
+     %mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>) -> vector<1x1x4x3x2xi8> {
+
+   %c0 = arith.constant 0 : index
+   %cst = arith.constant 0 : i8
+   %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0, %c0], %cst :
+     memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>, vector<1x1x4x3x2xi8>
+   return %res : vector<1x1x4x3x2xi8>
+ }
+
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
+ // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
+ // CHECK-SAME: -> vector<1x1x4x3x2xi8>
+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
+ // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+ // CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+ // CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
+ // CHECK-SAME: {in_bounds = [true]} : memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>, vector<24xi8>
+ // CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<24xi8> to vector<1x1x4x3x2xi8>
+ // CHECK: return %[[VEC]] : vector<1x1x4x3x2xi8>
+
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
+ // CHECK-128B: memref.collapse_shape
+
+ // -----
+
+ // The memref is non-contiguous, but the vector is a contiguous subset of the
+ // memref, so "flattenable". The leading unit dimensions of the vector have no
+ // effect on the memref area read even if they span the non-contiguous part of
+ // the memref.
+
+ func.func @transfer_read_non_contiguous_unit_dims(
+     %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
+
+   %c0 = arith.constant 0 : index
+   %cst = arith.constant 0 : i8
+   %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+     memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x3x2xi8>
+   return %res : vector<1x1x3x2xi8>
+ }
+
+ // CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
+ // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
+ // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+ // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<6xi8>
+ // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<6xi8> to vector<1x1x3x2xi8>
+ // CHECK: return %[[VAL_5]] : vector<1x1x3x2xi8>
+
+ // CHECK-128B-LABEL: func @transfer_read_non_contiguous_unit_dims(
// CHECK-128B: memref.collapse_shape
+

// -----

func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -418,61 +451,92 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
// -----

// The shape of the memref and the vector don't match, but the vector is a
- // contiguous subset of the memref, so "flattenable". The leading unit dimensions
- // of the vector have no effect on the memref area written even if they
- // span a non-contiguous part of the memref.
+ // contiguous subset of the memref, so "flattenable".

- func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
-     %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
-     %vec : vector<1x1x2x2xi8>) {
+ func.func @transfer_write_dims_mismatch_contiguous(
+     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+     %vec : vector<2x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
-     vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
+     vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>

  return
}

- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
- // CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+ // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+ // CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
- // CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
- // CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+ // CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
+
+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+ // CHECK-128B: memref.collapse_shape
+
+ // -----
+
+ // The shape of the memref and the vector don't match, but the mismatch is only
+ // at the leading unit dimensions of the vector.
+
+ func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
+     %mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>,
+     %vec : vector<1x1x4x3x2xi8>) {
+
+   %c0 = arith.constant 0 : index
+   vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0, %c0] :
+     vector<1x1x4x3x2xi8>, memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+
+   return
+ }
+
+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
+ // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+ // CHECK-SAME: %[[VEC:.+]]: vector<1x1x4x3x2xi8>
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
+ // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+ // CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x4x3x2xi8> to vector<24xi8>
+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+ // CHECK-SAME: {in_bounds = [true]} : vector<24xi8>, memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
// CHECK-128B: memref.collapse_shape

// -----

- // The shape of the memref and the vector don't match, but the vector is a
- // contiguous subset of the memref, so "flattenable".
+ // The memref is non-contiguous, but the vector is a contiguous subset of the
+ // memref, so "flattenable". The leading unit dimensions of the vector have no
+ // effect on the memref area written even if they span the non-contiguous part
+ // of the memref.

- func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
-     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
-     %vec : vector<2x2xi8>) {
+ func.func @transfer_write_non_contiguous_unit_dims(
+     %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
+     %vec : vector<1x1x3x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
-     vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+     vector<1x1x3x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>

  return
}

- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
- // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
- // CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
- // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+ // CHECK-LABEL: func.func @transfer_write_non_contiguous_unit_dims
+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
+ // CHECK-SAME: %[[VEC:.*]]: vector<1x1x3x2xi8>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
- // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
- // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
- // CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+ // CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x3x2xi8> to vector<6xi8>
+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+ // CHECK-SAME: {in_bounds = [true]} : vector<6xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>

- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
+ // CHECK-128B-LABEL: func @transfer_write_non_contiguous_unit_dims(
// CHECK-128B: memref.collapse_shape

// -----
@@ -718,4 +782,3 @@ func.func @negative_out_of_bound_transfer_write(
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast
-
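
For orientation, the flattening these tests exercise rewrites a contiguous n-D transfer into three ops: a memref.collapse_shape over the trailing contiguous dimensions, a single 1-D in-bounds transfer, and a vector.shape_cast back to the original vector shape. A minimal sketch of the read case follows; the shapes and the function names @example_before/@example_after are illustrative only and not taken from this patch. The write case is symmetric: shape_cast the vector down to 1-D first, then issue a 1-D vector.transfer_write into the collapsed memref.

// Input (hypothetical shapes): a contiguous read with a leading unit dim.
func.func @example_before(%m: memref<1x4x8xi8>) -> vector<1x4x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
      : memref<1x4x8xi8>, vector<1x4x8xi8>
  return %v : vector<1x4x8xi8>
}

// After flattening: collapse the memref, do one in-bounds 1-D read of all
// 32 contiguous elements, then shape_cast back to the original vector type.
func.func @example_after(%m: memref<1x4x8xi8>) -> vector<1x4x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %flat = memref.collapse_shape %m [[0, 1, 2]]
      : memref<1x4x8xi8> into memref<32xi8>
  %v1d = vector.transfer_read %flat[%c0], %pad {in_bounds = [true]}
      : memref<32xi8>, vector<32xi8>
  %v = vector.shape_cast %v1d : vector<32xi8> to vector<1x4x8xi8>
  return %v : vector<1x4x8xi8>
}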