@@ -70,10 +70,11 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 
 // -----
 
-// The shape of the memref and the vector don't match, but the vector is a
-// contiguous subset of the memref, so "flattenable".
+// The shape of the memref and the vector don't match, but the vector,
+// ignoring the unit dimensions, is a contiguous subset of the memref,
+// so "flattenable"
 
-func.func @transfer_read_dims_mismatch_contiguous(
+func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 
   %c0 = arith.constant 0 : index
@@ -83,7 +84,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
   return %res : vector<1x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
 // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -92,7 +93,37 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable"
+
+func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x3x2xi8>
+  return %res : vector<2x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
+// CHECK-SAME: : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
+// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
+// CHECK: return %[[VEC]] : vector<2x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
 // CHECK-128B: memref.collapse_shape
 
 // -----
@@ -380,7 +411,7 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
 
 // -----
 
-func.func @transfer_write_dims_mismatch_contiguous(
+func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
     %vec : vector<1x1x2x2xi8>) {
 
@@ -390,15 +421,41 @@ func.func @transfer_write_dims_mismatch_contiguous(
   return
 }
 
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
 // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
 // CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
 // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
 // CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
 // CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
 
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
+    vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
+// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
+// CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
+
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
 // CHECK-128B: memref.collapse_shape
 
 // -----
@@ -620,6 +677,7 @@ func.func @negative_out_of_bound_transfer_read(
 }
 // CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
 // CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
 // CHECK-128B-NOT: memref.collapse_shape
@@ -638,45 +696,9 @@ func.func @negative_out_of_bound_transfer_write(
 }
 // CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
 // CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
 // CHECK-128B-NOT: memref.collapse_shape
 // CHECK-128B-NOT: vector.shape_cast
 
-// -----
-
-func.func @discontig_mem_contig_slice(
-    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %mem[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
-  return
-}
-
-// CHECK-LABEL: func.func @discontig_mem_contig_slice
-// CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
-// CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
-// CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
-// CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
-
-// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
-// CHECK-128B-NOT: vector.shape_cast
-
-// -----
-
-func.func @discontig_mem_discontig_slice(
-    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %mem[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
-  return
-}
-
-// CHECK-LABEL: func.func @discontig_mem_discontig_slice
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
-// CHECK-128B-NOT: vector.shape_cast
-