@@ -16,6 +16,25 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
16
16
17
17
// -----
18
18
19
+ func.func @contiguous_outer_dyn_inner_most_view (%in: memref <?x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x1 xf32 >{
20
+ %c0 = arith.constant 0 : index
21
+ %cst = arith.constant 0.0 : f32
22
+ %0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <?x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x8 x1 xf32 >
23
+ return %0 : vector <1 x8 x1 xf32 >
24
+ }
25
+ // CHECK: func @contiguous_outer_dyn_inner_most_view(
26
+ // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
27
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
28
+ // CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
29
+ // CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1]
30
+ // CHECK-SAME: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>
31
+ // CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
32
+ // CHECK-SAME: memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
33
+ // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
34
+ // CHECK: return %[[RESULT]]
35
+
36
+ // -----
37
+
19
38
func.func @contiguous_inner_most_dim (%A: memref <16 x1 xf32 >, %i:index , %j:index ) -> (vector <8 x1 xf32 >) {
20
39
%c0 = arith.constant 0 : index
21
40
%f0 = arith.constant 0.0 : f32
@@ -119,6 +138,27 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
119
138
120
139
// -----
121
140
141
+ func.func @outer_dyn_drop_inner_most_dim_for_transfer_write (%arg0: memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
142
+ %c0 = arith.constant 0 : index
143
+ vector.transfer_write %arg1 , %arg0 [%arg2 , %c0 , %c0 , %c0 ]
144
+ {in_bounds = [true , true , true , true ]}
145
+ : vector <1 x16 x16 x1 xf32 >, memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>
146
+ return
147
+ }
148
+ // CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
149
+ // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
150
+ // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
151
+ // CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
152
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
153
+ // CHECK-DAG: %[[D0:.+]] = memref.dim %[[DEST]], %[[C0]]
154
+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
155
+ // CHECK-SAME: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
156
+ // CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
157
+ // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
158
+ // CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]]
159
+
160
+ // -----
161
+
122
162
func.func @non_unit_strides (%arg0: memref <512 x16 x1 xf32 , strided <[8192 , 16 , 4 ], offset : ?>>, %arg1: vector <16 x16 x1 xf32 >, %arg2: index ) {
123
163
%c0 = arith.constant 0 : index
124
164
vector.transfer_write %arg1 , %arg0 [%arg2 , %c0 , %c0 ]
0 commit comments