// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s

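+// The shapes of the memref and the vector match exactly, so this transfer
+// should collapse into a single 1-D read (see the CHECK lines below).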
-func.func @transfer_read_flattenable_with_offset(
+func.func @transfer_read_dims_match_contiguous(
    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
    return %v : vector<5x4x3x2xi8>
}

-// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,15 +18,53 @@ func.func @transfer_read_flattenable_with_offset(

// -----

-func.func @transfer_write_flattenable_with_offset(
+// The shapes of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+    return %v : vector<1x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
+// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
+
+// -----
+
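+// The shapes don't match and, unlike the case above, the vector is *not* a
+// contiguous subset of the memref (the leading dim of 2 skips over elements),
+// so no flattening is expected here.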
+func.func @transfer_read_dims_mismatch_non_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+    return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
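+// Same as the first test, but for vector.transfer_write: the shapes match
+// exactly, so the write should be flattened into a single 1-D write.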
+func.func @transfer_write_dims_match_contiguous(
    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
    %c0 = arith.constant 0 : index
    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
      vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
    return
}

-// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +73,48 @@ func.func @transfer_write_flattenable_with_offset(

// -----

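+// As in the read case above, the shapes don't match but the vector is a
+// contiguous subset of the memref, so the write should still be flattened.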
+func.func @transfer_write_dims_mismatch_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+      vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK: return
+// CHECK: }
+
+// -----
+
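+// The vector is not a contiguous subset of the memref, so this write should
+// be left unflattened (no collapse_shape / shape_cast expected below).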
+func.func @transfer_write_dims_mismatch_non_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+      vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
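+// 0-D transfers have no dimensions to collapse, so the 0-D reads and writes
+// below are expected to be left untouched.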
func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
    vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
    return
}

-// CHECK-LABEL: func @transfer_write_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK-SAME: %[[VEC:.+]]: vector<i8>
-// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
-// CHECK: return
+// CHECK-LABEL: func.func @transfer_write_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast

// -----

@@ -54,11 +124,9 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
    return %0 : vector<i8>
}

-// CHECK-LABEL: func @transfer_read_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK: %[[CST:.+]] = arith.constant 0 : i8
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
-// CHECK: return %[[READ]]
+// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast

// -----