@@ -70,10 +70,11 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
70
70
71
71
// -----
72
72
73
- // The shape of the memref and the vector don't match, but the vector is a
74
- // contiguous subset of the memref, so "flattenable".
73
+ // The shape of the memref and the vector don't match, but the vector,
74
+ // ignoring the unit dimensions, is a contiguous subset of the memref,
75
+ // so "flattenable"
75
76
76
- func.func @transfer_read_dims_mismatch_contiguous (
77
+ func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
77
78
%mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
78
79
79
80
%c0 = arith.constant 0 : index
@@ -83,7 +84,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
83
84
return %res : vector <1 x1 x2 x2 xi8 >
84
85
}
85
86
86
- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous (
87
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
87
88
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
88
89
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
89
90
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -92,7 +93,37 @@ func.func @transfer_read_dims_mismatch_contiguous(
92
93
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
93
94
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
94
95
95
- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
96
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
97
+ // CHECK-128B: memref.collapse_shape
98
+
99
+ // -----
100
+
101
+ // The shape of the memref and the vector don't match, but the vector is a
102
+ // contiguous subset of the memref, so "flattenable"
103
+
104
+ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims (
105
+ %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <2 x3 x2 xi8 > {
106
+
107
+ %c0 = arith.constant 0 : index
108
+ %cst = arith.constant 0 : i8
109
+ %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
110
+ memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <2 x3 x2 xi8 >
111
+ return %res : vector <2 x3 x2 xi8 >
112
+ }
113
+
114
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
115
+ // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
116
+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
117
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
118
+ // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
119
+ // CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
120
+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
121
+ // CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
122
+ // CHECK-SAME: : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
123
+ // CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
124
+ // CHECK: return %[[VEC]] : vector<2x3x2xi8>
125
+
126
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
96
127
// CHECK-128B: memref.collapse_shape
97
128
98
129
// -----
@@ -384,7 +415,7 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
384
415
385
416
// -----
386
417
387
- func.func @transfer_write_dims_mismatch_contiguous (
418
+ func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
388
419
%mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
389
420
%vec : vector <1 x1 x2 x2 xi8 >) {
390
421
@@ -394,15 +425,41 @@ func.func @transfer_write_dims_mismatch_contiguous(
394
425
return
395
426
}
396
427
397
- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
428
+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
398
429
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
399
430
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
400
431
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
401
432
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
402
433
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
403
434
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
404
435
405
- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
436
+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
437
+ // CHECK-128B: memref.collapse_shape
438
+
439
+ // -----
440
+
441
+ func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
442
+ %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
443
+ %vec : vector <2 x2 xi8 >) {
444
+
445
+ %c0 = arith.constant 0 : index
446
+ vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
447
+ vector <2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
448
+ return
449
+ }
450
+
451
+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
452
+ // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
453
+ // CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
454
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
455
+ // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
456
+ // CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
457
+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
458
+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
459
+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
460
+ // CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
461
+
462
+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
406
463
// CHECK-128B: memref.collapse_shape
407
464
408
465
// -----
@@ -626,6 +683,7 @@ func.func @negative_out_of_bound_transfer_read(
626
683
}
627
684
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
628
685
// CHECK-NOT: memref.collapse_shape
686
+ // CHECK-NOT: vector.shape_cast
629
687
630
688
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
631
689
// CHECK-128B-NOT: memref.collapse_shape
@@ -642,45 +700,9 @@ func.func @negative_out_of_bound_transfer_write(
642
700
}
643
701
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
644
702
// CHECK-NOT: memref.collapse_shape
703
+ // CHECK-NOT: vector.shape_cast
645
704
646
705
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
647
706
// CHECK-128B-NOT: memref.collapse_shape
648
707
// CHECK-128B-NOT: vector.shape_cast
649
708
650
- // -----
651
-
652
- func.func @discontig_mem_contig_slice (
653
- %mem : memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>, %vec : vector <1 x1 x8 xi32 >) {
654
- %c0 = arith.constant 0 : index
655
- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 ] {in_bounds = [true , true , true ]} :
656
- vector <1 x1 x8 xi32 >, memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>
657
- return
658
- }
659
-
660
- // CHECK-LABEL: func.func @discontig_mem_contig_slice
661
- // CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
662
- // CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
663
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
664
- // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
665
- // CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
666
- // CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
667
-
668
- // CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
669
- // CHECK-128B-NOT: vector.shape_cast
670
-
671
- // -----
672
-
673
- func.func @discontig_mem_discontig_slice (
674
- %mem : memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>, %vec : vector <1 x2 x8 xi32 >) {
675
- %c0 = arith.constant 0 : index
676
- vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 ] {in_bounds = [true , true , true ]} :
677
- vector <1 x2 x8 xi32 >, memref <8 x8 x8 xi32 , strided <[128 , 16 , 1 ]>>
678
- return
679
- }
680
-
681
- // CHECK-LABEL: func.func @discontig_mem_discontig_slice
682
- // CHECK-NOT: vector.shape_cast
683
-
684
- // CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
685
- // CHECK-128B-NOT: vector.shape_cast
686
-
0 commit comments