1
1
// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
2
2
3
- func.func @contiguous_inner_most_view (%in: memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x1 xf32 >{
3
+ //-----------------------------------------------------------------------------
4
+ // 1. vector.transfer_read
5
+ //-----------------------------------------------------------------------------
6
+
7
+ // Reads vector<1x8x1xf32> from a strided memref<1x1x8x1xf32> at all-zero
+ // indices. The expected output drops the trailing unit dim through a
+ // memref.subview, reads vector<1x8xf32>, and then applies vector.shape_cast.
+ func.func @contiguous_inner_most (%in: memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x1 xf32 >{
4
8
%c0 = arith.constant 0 : index
5
9
%cst = arith.constant 0.0 : f32
6
10
%0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x8 x1 xf32 >
7
11
return %0 : vector <1 x8 x1 xf32 >
8
12
}
9
- // CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
13
+
14
+ // CHECK: func @contiguous_inner_most(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
10
15
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
11
16
// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
12
17
// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
13
18
// CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
14
19
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
15
20
// CHECK: return %[[RESULT]]
16
21
22
+ // Same as the top example within this split, but with the inner vector dim
23
+ // scalable. This example only makes sense when "8 = [8]" (i.e. vscale = 1),
24
+ // which the `in_bounds` attribute implicitly assumes. The trailing unit dim is
+ // still expected to be collapsed, yielding a vector<1x[8]xf32> read.
25
+
26
+ func.func @contiguous_inner_most_scalable_inner_dim (%in: memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x[8 ]x1 xf32 >{
27
+ %c0 = arith.constant 0 : index
28
+ %cst = arith.constant 0.0 : f32
29
+ %0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x[8 ]x1 xf32 >
30
+ return %0 : vector <1 x[8 ]x1 xf32 >
31
+ }
32
+
33
+ // CHECK: func @contiguous_inner_most_scalable_inner_dim(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
34
+ // CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
35
+ // CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
36
+ // CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
37
+ // CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x[8]xf32>
38
+ // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
39
+ // CHECK: return %[[RESULT]]
40
+
41
+ // Same as the top example within this split, but the trailing unit dim of the
41
+ // memref was replaced with a dynamic dim (`?`) - not supported, so neither a
42
+ // memref.subview nor a vector.shape_cast is expected in the output.
43
+
44
+ func.func @non_unit_trailing_dim (%in: memref <1 x1 x8 x?xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x1 xf32 >{
45
+ %c0 = arith.constant 0 : index
46
+ %cst = arith.constant 0.0 : f32
47
+ %0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <1 x1 x8 x?xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x8 x1 xf32 >
48
+ return %0 : vector <1 x8 x1 xf32 >
49
+ }
50
+
51
+ // CHECK-LABEL: func @non_unit_trailing_dim
52
+ // CHECK-NOT: memref.subview
53
+ // CHECK-NOT: vector.shape_cast
54
+
55
+ // Same as the top example within this split, but with a scalable unit dim in
56
+ // the output vector - not supported ([1], i.e. scalable 1, is _not_ a unit dim).
57
+
58
+ func.func @negative_scalable_unit_dim (%in: memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x[1 ]xf32 >{
59
+ %c0 = arith.constant 0 : index
60
+ %cst = arith.constant 0.0 : f32
61
+ %0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x8 x[1 ]xf32 >
62
+ return %0 : vector <1 x8 x[1 ]xf32 >
63
+ }
64
+ // CHECK-LABEL: func @negative_scalable_unit_dim
65
+ // CHECK-NOT: memref.subview
66
+ // CHECK-NOT: vector.shape_cast
67
+
17
68
// -----
18
69
19
- func.func @contiguous_outer_dyn_inner_most_view (%a: index , %b: index , %memref: memref <?x?x8 x1 xf32 >) -> vector <8 x1 xf32 > {
70
+ // Reads vector<8x1xf32> from memref<?x?x8x1xf32>: the two dynamic outer dims
+ // are indexed by %a and %b, while the two static inner dims are indexed by 0.
+ func.func @contiguous_outer_dyn_inner_most (%a: index , %b: index , %memref: memref <?x?x8 x1 xf32 >) -> vector <8 x1 xf32 > {
20
71
%c0 = arith.constant 0 : index
21
72
%pad = arith.constant 0.0 : f32
22
73
%v = vector.transfer_read %memref [%a , %b , %c0 , %c0 ], %pad {in_bounds = [true , true ]} : memref <?x?x8 x1 xf32 >, vector <8 x1 xf32 >
23
74
return %v : vector <8 x1 xf32 >
24
75
}
25
- // CHECK: func.func @contiguous_outer_dyn_inner_most_view (
76
+ // CHECK: func.func @contiguous_outer_dyn_inner_most (
26
77
// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]
27
78
// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]
28
79
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
@@ -103,6 +154,10 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
103
154
104
155
// -----
105
156
157
+ //-----------------------------------------------------------------------------
158
+ // 2. vector.transfer_write
159
+ //-----------------------------------------------------------------------------
160
+
106
161
func.func @drop_two_inner_most_dim_for_transfer_write (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
107
162
%c0 = arith.constant 0 : index
108
163
vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
@@ -177,21 +232,6 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
177
232
178
233
// -----
179
234
180
- func.func @leading_scalable_dimension_transfer_read (%dest : memref <24 x1 xf32 >) -> vector <[4 ]x1 xf32 > {
181
- %c0 = arith.constant 0 : index
182
- %pad = arith.constant 0.0 : f32
183
- %0 = vector.transfer_read %dest [%c0 , %c0 ], %pad {in_bounds = [true , true ]} : memref <24 x1 xf32 >, vector <[4 ]x1 xf32 >
184
- return %0 : vector <[4 ]x1 xf32 >
185
- }
186
- // CHECK: func.func @leading_scalable_dimension_transfer_read
187
- // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
188
- // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
189
- // CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
190
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32>
191
- // CHECK: return %[[CAST]]
192
-
193
- // -----
194
-
195
235
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
196
236
func.func @trailing_scalable_one_dim_transfer_read (%dest : memref <24 x1 xf32 >) -> vector <4 x[1 ]xf32 > {
197
237
%c0 = arith.constant 0 : index
@@ -217,16 +257,3 @@ func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %
217
257
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
218
258
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
219
259
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
220
-
221
- // -----
222
-
223
- // Negative test: [1] (scalable 1) is _not_ a unit dimension.
224
- func.func @trailing_scalable_one_dim_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <4 x[1 ]xf32 >, %index: index ) {
225
- %c0 = arith.constant 0 : index
226
- vector.transfer_write %vec , %dest [%index , %c0 ] {in_bounds = [true , true ]} : vector <4 x[1 ]xf32 >, memref <24 x1 xf32 >
227
- return
228
- }
229
- // CHECK: func.func @trailing_scalable_one_dim_transfer_write
230
- // CHECK-NOT: vector.shape_cast
231
- // CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
232
- // CHECK-NOT: vector.shape_cast
0 commit comments