Skip to content

Commit 1740d45

Browse files
[fixup] Add/change a few tests
1 parent 4a1b0ef commit 1740d45

File tree

1 file changed

+130
-67
lines changed

1 file changed

+130
-67
lines changed

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 130 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -70,41 +70,10 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070

7171
// -----
7272

73-
// The shape of the memref and the vector don't match, but the vector is a
74-
// contiguous subset of the memref, so "flattenable". The leading unit dimensions
75-
// of the vector have no effect on the memref area read even if they
76-
// span a non-contiguous part of the memref.
77-
78-
func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
79-
%mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
80-
81-
%c0 = arith.constant 0 : index
82-
%cst = arith.constant 0 : i8
83-
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
84-
memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
85-
return %res : vector<1x1x2x2xi8>
86-
}
87-
88-
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
89-
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
90-
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
91-
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
92-
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
93-
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
94-
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
95-
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
96-
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
97-
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
98-
99-
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
100-
// CHECK-128B: memref.collapse_shape
101-
102-
// -----
103-
10473
// The shape of the memref and the vector don't match, but the vector is a
10574
// contiguous subset of the memref, so "flattenable"
10675

107-
func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
76+
func.func @transfer_read_dims_mismatch_contiguous(
10877
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
10978

11079
%c0 = arith.constant 0 : index
@@ -114,7 +83,7 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
11483
return %res : vector<2x3x2xi8>
11584
}
11685

117-
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
86+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
11887
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
11988
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
12089
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -126,9 +95,73 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
12695
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
12796
// CHECK: return %[[VEC]] : vector<2x3x2xi8>
12897

129-
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
98+
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
99+
// CHECK-128B: memref.collapse_shape
100+
101+
// -----
102+
103+
// The shape of the memref and the vector don't match, but the mismatch is only
104+
// at the leading unit dimensions of the vector.
105+
106+
func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
107+
%mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>) -> vector<1x1x4x3x2xi8> {
108+
109+
%c0 = arith.constant 0 : index
110+
%cst = arith.constant 0 : i8
111+
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0, %c0], %cst :
112+
memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>, vector<1x1x4x3x2xi8>
113+
return %res : vector<1x1x4x3x2xi8>
114+
}
115+
116+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
117+
// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
118+
// CHECK-SAME: -> vector<1x1x4x3x2xi8>
119+
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
120+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
121+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
122+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
123+
// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
124+
// CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
125+
// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
126+
// CHECK-SAME: {in_bounds = [true]} : memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>, vector<24xi8>
127+
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<24xi8> to vector<1x1x4x3x2xi8>
128+
// CHECK: return %[[VEC]] : vector<1x1x4x3x2xi8>
129+
130+
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
131+
// CHECK-128B: memref.collapse_shape
132+
133+
// -----
134+
135+
// The memref is non-contiguous, but the vector is a contiguous subset of the
136+
// memref, so "flattenable". The leading unit dimensions of the vector have no
137+
// effect on the memref area read even if they span the non-contiguous part of
138+
// the memref.
139+
140+
func.func @transfer_read_non_contiguous_unit_dims(
141+
%mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
142+
143+
%c0 = arith.constant 0 : index
144+
%cst = arith.constant 0 : i8
145+
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
146+
memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x3x2xi8>
147+
return %res : vector<1x1x3x2xi8>
148+
}
149+
150+
// CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
151+
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
152+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
153+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
154+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
155+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
156+
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
157+
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<6xi8>
158+
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<6xi8> to vector<1x1x3x2xi8>
159+
// CHECK: return %[[VAL_5]] : vector<1x1x3x2xi8>
160+
161+
// CHECK-128B-LABEL: func @transfer_read_non_contiguous_unit_dims(
130162
// CHECK-128B: memref.collapse_shape
131163

164+
132165
// -----
133166

134167
func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -418,61 +451,92 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
418451
// -----
419452

420453
// The shape of the memref and the vector don't match, but the vector is a
421-
// contiguous subset of the memref, so "flattenable". The leading unit dimensions
422-
// of the vector have no effect on the memref area written even if they
423-
// span a non-contiguous part of the memref.
454+
// contiguous subset of the memref, so "flattenable".
424455

425-
func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
426-
%mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
427-
%vec : vector<1x1x2x2xi8>) {
456+
func.func @transfer_write_dims_mismatch_contiguous(
457+
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
458+
%vec : vector<2x2xi8>) {
428459

429460
%c0 = arith.constant 0 : index
430461
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
431-
vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
462+
vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
432463
return
433464
}
434465

435-
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
436-
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
437-
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
438-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
439-
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
466+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
467+
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
468+
// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
469+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
470+
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
440471
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
441-
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
442-
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
443-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
444-
// CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
472+
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
473+
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
474+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
475+
// CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
476+
477+
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
478+
// CHECK-128B: memref.collapse_shape
479+
480+
// -----
481+
482+
// The shape of the memref and the vector don't match, but the mismatch is only
483+
// at the leading unit dimensions of the vector.
484+
485+
func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
486+
%mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>,
487+
%vec : vector<1x1x4x3x2xi8>) {
488+
489+
%c0 = arith.constant 0 : index
490+
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0, %c0] :
491+
vector<1x1x4x3x2xi8>, memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
492+
493+
return
494+
}
495+
496+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
497+
// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
498+
// CHECK-SAME: %[[VEC:.+]]: vector<1x1x4x3x2xi8>
499+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
500+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
501+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
502+
// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
503+
// CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
504+
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x4x3x2xi8> to vector<24xi8>
505+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
506+
// CHECK-SAME: {in_bounds = [true]} : vector<24xi8>, memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
445507

446508
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
447509
// CHECK-128B: memref.collapse_shape
448510

449511
// -----
450512

451-
// The shape of the memref and the vector don't match, but the vector is a
452-
// contiguous subset of the memref, so "flattenable".
513+
// The memref is non-contiguous, but the vector is a contiguous subset of the
514+
// memref, so "flattenable". The leading unit dimensions of the vector have no
515+
// effect on the memref area read even if they span the non-contiguous part of
516+
// the memref.
453517

454-
func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
455-
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
456-
%vec : vector<2x2xi8>) {
518+
func.func @transfer_write_non_contiguous_unit_dims(
519+
%mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
520+
%vec : vector<1x1x3x2xi8>) {
457521

458522
%c0 = arith.constant 0 : index
459523
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
460-
vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
524+
vector<1x1x3x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
461525
return
462526
}
463527

464-
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
465-
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
466-
// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
467-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
468-
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
528+
// CHECK-LABEL: func.func @transfer_write_non_contiguous_unit_dims
529+
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
530+
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x3x2xi8>) {
531+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
532+
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
469533
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
470-
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
471-
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
472-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
473-
// CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
534+
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
535+
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x3x2xi8> to vector<6xi8>
536+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
537+
// CHECK-SAME: {in_bounds = [true]} : vector<6xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
474538

475-
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
539+
// CHECK-128B-LABEL: func @transfer_write_non_contiguous_unit_dims(
476540
// CHECK-128B: memref.collapse_shape
477541

478542
// -----
@@ -718,4 +782,3 @@ func.func @negative_out_of_bound_transfer_write(
718782
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
719783
// CHECK-128B-NOT: memref.collapse_shape
720784
// CHECK-128B-NOT: vector.shape_cast
721-

0 commit comments

Comments
 (0)