Skip to content

Commit 3b17c94

Browse files
[fixup] Test tweaks for better coverage
1 parent 5e66da1 commit 3b17c94

File tree

1 file changed

+67
-45
lines changed

1 file changed

+67
-45
lines changed

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070

7171
// -----
7272

73-
// The shape of the memref and the vector don't match, but the vector is a
74-
// contiguous subset of the memref, so "flattenable".
73+
// The shape of the memref and the vector don't match, but the vector,
74+
// ignoring the unit dimensions, is a contiguous subset of the memref,
75+
// so "flattenable"
7576

76-
func.func @transfer_read_dims_mismatch_contiguous(
77+
func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
7778
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
7879

7980
%c0 = arith.constant 0 : index
@@ -83,7 +84,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
8384
return %res : vector<1x1x2x2xi8>
8485
}
8586

86-
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
87+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
8788
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
8889
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
8990
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -92,7 +93,37 @@ func.func @transfer_read_dims_mismatch_contiguous(
9293
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
9394
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
9495

95-
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
96+
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
97+
// CHECK-128B: memref.collapse_shape
98+
99+
// -----
100+
101+
// The shape of the memref and the vector don't match, but the vector is a
102+
// contiguous subset of the memref, so "flattenable"
103+
104+
func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
105+
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
106+
107+
%c0 = arith.constant 0 : index
108+
%cst = arith.constant 0 : i8
109+
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
110+
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x3x2xi8>
111+
return %res : vector<2x3x2xi8>
112+
}
113+
114+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
115+
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
116+
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
117+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
118+
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
119+
// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
120+
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
121+
// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
122+
// CHECK-SAME: : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
123+
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
124+
// CHECK: return %[[VEC]] : vector<2x3x2xi8>
125+
126+
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
96127
// CHECK-128B: memref.collapse_shape
97128

98129
// -----
@@ -380,7 +411,7 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
380411

381412
// -----
382413

383-
func.func @transfer_write_dims_mismatch_contiguous(
414+
func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
384415
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
385416
%vec : vector<1x1x2x2xi8>) {
386417

@@ -390,15 +421,41 @@ func.func @transfer_write_dims_mismatch_contiguous(
390421
return
391422
}
392423

393-
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
424+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
394425
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
395426
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
396427
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
397428
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
398429
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
399430
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
400431

401-
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
432+
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
433+
// CHECK-128B: memref.collapse_shape
434+
435+
// -----
436+
437+
func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
438+
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
439+
%vec : vector<2x2xi8>) {
440+
441+
%c0 = arith.constant 0 : index
442+
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
443+
vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
444+
return
445+
}
446+
447+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
448+
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
449+
// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
450+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
451+
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
452+
// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
453+
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
454+
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
455+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
456+
// CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
457+
458+
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
402459
// CHECK-128B: memref.collapse_shape
403460

404461
// -----
@@ -620,6 +677,7 @@ func.func @negative_out_of_bound_transfer_read(
620677
}
621678
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
622679
// CHECK-NOT: memref.collapse_shape
680+
// CHECK-NOT: vector.shape_cast
623681

624682
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
625683
// CHECK-128B-NOT: memref.collapse_shape
@@ -638,45 +696,9 @@ func.func @negative_out_of_bound_transfer_write(
638696
}
639697
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
640698
// CHECK-NOT: memref.collapse_shape
699+
// CHECK-NOT: vector.shape_cast
641700

642701
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
643702
// CHECK-128B-NOT: memref.collapse_shape
644703
// CHECK-128B-NOT: vector.shape_cast
645704

646-
// -----
647-
648-
func.func @discontig_mem_contig_slice(
649-
%mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
650-
%c0 = arith.constant 0 : index
651-
vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
652-
vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
653-
return
654-
}
655-
656-
// CHECK-LABEL: func.func @discontig_mem_contig_slice
657-
// CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
658-
// CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
659-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
660-
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
661-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
662-
// CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
663-
664-
// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
665-
// CHECK-128B-NOT: vector.shape_cast
666-
667-
// -----
668-
669-
func.func @discontig_mem_discontig_slice(
670-
%mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
671-
%c0 = arith.constant 0 : index
672-
vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
673-
vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
674-
return
675-
}
676-
677-
// CHECK-LABEL: func.func @discontig_mem_discontig_slice
678-
// CHECK-NOT: vector.shape_cast
679-
680-
// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
681-
// CHECK-128B-NOT: vector.shape_cast
682-

0 commit comments

Comments
 (0)