1
1
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
2
2
3
- func.func @transfer_read_flattenable_with_offset (
3
+ func.func @transfer_read_dims_match_contiguous (
4
4
%arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <5 x4 x3 x2 xi8 > {
5
5
%c0 = arith.constant 0 : index
6
6
%cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
9
9
return %v : vector <5 x4 x3 x2 xi8 >
10
10
}
11
11
12
- // CHECK-LABEL: func @transfer_read_flattenable_with_offset
12
+ // CHECK-LABEL: func @transfer_read_dims_match_contiguous
13
13
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
14
14
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
15
15
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,15 +18,52 @@ func.func @transfer_read_flattenable_with_offset(
18
18
19
19
// -----
20
20
21
- func.func @transfer_write_flattenable_with_offset (
21
+ // The shape of the memref and the vector don't match, but the vector is a
22
+ // contiguous subset of the memref, so "flattenable".
23
+
24
+ func.func @transfer_read_dims_mismatch_contiguous (
25
+ %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
26
+ %c0 = arith.constant 0 : index
27
+ %cst = arith.constant 0 : i8
28
+ %v = vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 ], %cst :
29
+ memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
30
+ return %v : vector <1 x1 x2 x2 xi8 >
31
+ }
32
+
33
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
34
+ // CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
35
+ // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
36
+ // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
37
+ // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
38
+ // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
39
+ // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
40
+ // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
41
+
42
+ // -----
43
+
44
+ func.func @transfer_read_dims_mismatch_contiguous (
45
+ %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <2 x1 x2 x2 xi8 > {
46
+ %c0 = arith.constant 0 : index
47
+ %cst = arith.constant 0 : i8
48
+ %v = vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 ], %cst :
49
+ memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <2 x1 x2 x2 xi8 >
50
+ return %v : vector <2 x1 x2 x2 xi8 >
51
+ }
52
+
53
+ // CHECK-NOT: memref.collapse_shape
54
+ // CHECK-NOT: vector.shape_cast
55
+
56
+ // -----
57
+
58
+ func.func @transfer_write_dims_match_contiguous (
22
59
%arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, %vec : vector <5 x4 x3 x2 xi8 >) {
23
60
%c0 = arith.constant 0 : index
24
61
vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
25
62
vector <5 x4 x3 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
26
63
return
27
64
}
28
65
29
- // CHECK-LABEL: func @transfer_write_flattenable_with_offset
66
+ // CHECK-LABEL: func @transfer_write_dims_match_contiguous
30
67
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
31
68
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
32
69
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +72,46 @@ func.func @transfer_write_flattenable_with_offset(
35
72
36
73
// -----
37
74
75
+ func.func @transfer_write_dims_mismatch_contiguous (
76
+ %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, %vec : vector <1 x1 x2 x2 xi8 >) {
77
+ %c0 = arith.constant 0 : index
78
+ vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
79
+ vector <1 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
80
+ return
81
+ }
82
+
83
+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
84
+ // CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
85
+ // CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
86
+ // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
87
+ // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
88
+ // CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
89
+ // CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
90
+ // CHECK: return
91
+ // CHECK: }
92
+
93
+ // -----
94
+
95
+ func.func @transfer_write_dims_mismatch_non_contiguous (
96
+ %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, %vec : vector <2 x1 x2 x2 xi8 >) {
97
+ %c0 = arith.constant 0 : index
98
+ vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
99
+ vector <2 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
100
+ return
101
+ }
102
+
103
+ // CHECK-NOT: memref.collapse_shape
104
+ // CHECK-NOT: vector.shape_cast
105
+
106
+ // -----
107
+
38
108
func.func @transfer_write_0d (%arg : memref <i8 >, %vec : vector <i8 >) {
39
109
vector.transfer_write %vec , %arg [] : vector <i8 >, memref <i8 >
40
110
return
41
111
}
42
112
43
- // CHECK-LABEL: func @transfer_write_0d
44
- // CHECK-SAME: %[[ARG:.+]]: memref<i8>
45
- // CHECK-SAME: %[[VEC:.+]]: vector<i8>
46
- // CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
47
- // CHECK: return
113
+ // CHECK-NOT: memref.collapse_shape
114
+ // CHECK-NOT: vector.shape_cast
48
115
49
116
// -----
50
117
@@ -54,11 +121,8 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
54
121
return %0 : vector <i8 >
55
122
}
56
123
57
- // CHECK-LABEL: func @transfer_read_0d
58
- // CHECK-SAME: %[[ARG:.+]]: memref<i8>
59
- // CHECK: %[[CST:.+]] = arith.constant 0 : i8
60
- // CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
61
- // CHECK: return %[[READ]]
124
+ // CHECK-NOT: memref.collapse_shape
125
+ // CHECK-NOT: vector.shape_cast
62
126
63
127
// -----
64
128
0 commit comments