Skip to content

Commit 25f7bdd

Browse files
[fixup] Test tweaks for better coverage
1 parent dc05e13 commit 25f7bdd

File tree

1 file changed

+67
-45
lines changed

1 file changed

+67
-45
lines changed

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070

7171
// -----
7272

73-
// The shape of the memref and the vector don't match, but the vector is a
74-
// contiguous subset of the memref, so "flattenable".
73+
// The shape of the memref and the vector don't match, but the vector,
74+
// ignoring the unit dimensions, is a contiguous subset of the memref,
75+
// so "flattenable"
7576

76-
func.func @transfer_read_dims_mismatch_contiguous(
77+
func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
7778
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
7879

7980
%c0 = arith.constant 0 : index
@@ -83,7 +84,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
8384
return %res : vector<1x1x2x2xi8>
8485
}
8586

86-
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
87+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
8788
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
8889
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
8990
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -92,7 +93,37 @@ func.func @transfer_read_dims_mismatch_contiguous(
9293
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
9394
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
9495

95-
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
96+
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
97+
// CHECK-128B: memref.collapse_shape
98+
99+
// -----
100+
101+
// The shape of the memref and the vector don't match, but the vector is a
102+
// contiguous subset of the memref, so "flattenable"
103+
104+
func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
105+
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
106+
107+
%c0 = arith.constant 0 : index
108+
%cst = arith.constant 0 : i8
109+
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
110+
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x3x2xi8>
111+
return %res : vector<2x3x2xi8>
112+
}
113+
114+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
115+
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
116+
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
117+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
118+
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
119+
// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
120+
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
121+
// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
122+
// CHECK-SAME: : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
123+
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
124+
// CHECK: return %[[VEC]] : vector<2x3x2xi8>
125+
126+
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
96127
// CHECK-128B: memref.collapse_shape
97128

98129
// -----
@@ -384,7 +415,7 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
384415

385416
// -----
386417

387-
func.func @transfer_write_dims_mismatch_contiguous(
418+
func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
388419
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
389420
%vec : vector<1x1x2x2xi8>) {
390421

@@ -394,15 +425,41 @@ func.func @transfer_write_dims_mismatch_contiguous(
394425
return
395426
}
396427

397-
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
428+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
398429
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
399430
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
400431
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
401432
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
402433
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
403434
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
404435

405-
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
436+
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
437+
// CHECK-128B: memref.collapse_shape
438+
439+
// -----
440+
441+
func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
442+
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
443+
%vec : vector<2x2xi8>) {
444+
445+
%c0 = arith.constant 0 : index
446+
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
447+
vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
448+
return
449+
}
450+
451+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
452+
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
453+
// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
454+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
455+
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
456+
// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
457+
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
458+
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
459+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
460+
// CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
461+
462+
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
406463
// CHECK-128B: memref.collapse_shape
407464

408465
// -----
@@ -626,6 +683,7 @@ func.func @negative_out_of_bound_transfer_read(
626683
}
627684
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
628685
// CHECK-NOT: memref.collapse_shape
686+
// CHECK-NOT: vector.shape_cast
629687

630688
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
631689
// CHECK-128B-NOT: memref.collapse_shape
@@ -642,45 +700,9 @@ func.func @negative_out_of_bound_transfer_write(
642700
}
643701
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
644702
// CHECK-NOT: memref.collapse_shape
703+
// CHECK-NOT: vector.shape_cast
645704

646705
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
647706
// CHECK-128B-NOT: memref.collapse_shape
648707
// CHECK-128B-NOT: vector.shape_cast
649708

650-
// -----
651-
652-
func.func @discontig_mem_contig_slice(
653-
%mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
654-
%c0 = arith.constant 0 : index
655-
vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
656-
vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
657-
return
658-
}
659-
660-
// CHECK-LABEL: func.func @discontig_mem_contig_slice
661-
// CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
662-
// CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
663-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
664-
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
665-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
666-
// CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
667-
668-
// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
669-
// CHECK-128B-NOT: vector.shape_cast
670-
671-
// -----
672-
673-
func.func @discontig_mem_discontig_slice(
674-
%mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
675-
%c0 = arith.constant 0 : index
676-
vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
677-
vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
678-
return
679-
}
680-
681-
// CHECK-LABEL: func.func @discontig_mem_discontig_slice
682-
// CHECK-NOT: vector.shape_cast
683-
684-
// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
685-
// CHECK-128B-NOT: vector.shape_cast
686-

0 commit comments

Comments
 (0)