@@ -179,9 +179,9 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
179
179
180
180
func.func @leading_scalable_dimension_transfer_read (%dest : memref <24 x1 xf32 >, %index: index ) -> vector <[4 ]x1 xf32 > {
181
181
%c0 = arith.constant 0 : index
182
- %cst_0 = arith.constant 0.000000e+00 : f32
183
- %4 = vector.transfer_read %dest [%index , %c0 ], %cst_0 {in_bounds = [true , true ]} : memref <24 x1 xf32 >, vector <[4 ]x1 xf32 >
184
- return %4 : vector <[4 ]x1 xf32 >
182
+ %pad = arith.constant 0.0 : f32
183
+ %0 = vector.transfer_read %dest [%index , %c0 ], %pad {in_bounds = [true , true ]} : memref <24 x1 xf32 >, vector <[4 ]x1 xf32 >
184
+ return %0 : vector <[4 ]x1 xf32 >
185
185
}
186
186
// CHECK: func.func @leading_scalable_dimension_transfer_read
187
187
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
@@ -196,11 +196,39 @@ func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %i
196
196
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
197
197
func.func @trailing_scalable_one_dim_transfer_read (%dest : memref <24 x1 xf32 >, %index: index ) -> vector <4 x[1 ]xf32 > {
198
198
%c0 = arith.constant 0 : index
199
- %cst_0 = arith.constant 0.000000e+00 : f32
200
- %4 = vector.transfer_read %dest [%index , %c0 ], %cst_0 {in_bounds = [true , true ]} : memref <24 x1 xf32 >, vector <4 x[1 ]xf32 >
201
- return %4 : vector <4 x[1 ]xf32 >
199
+ %pad = arith.constant 0.0 : f32
200
+ %0 = vector.transfer_read %dest [%index , %c0 ], %pad {in_bounds = [true , true ]} : memref <24 x1 xf32 >, vector <4 x[1 ]xf32 >
201
+ return %0 : vector <4 x[1 ]xf32 >
202
202
}
203
203
// CHECK: func.func @trailing_scalable_one_dim_transfer_read
204
204
// CHECK-NOT: vector.shape_cast
205
205
// CHECK: vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
206
206
// CHECK-NOT: vector.shape_cast
207
+
208
+ // -----
209
+
210
+ func.func @leading_scalable_dimension_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <[4 ]x1 xf32 >, %index: index ) {
211
+ %c0 = arith.constant 0 : index
212
+ vector.transfer_write %vec , %dest [%index , %c0 ] {in_bounds = [true , true ]} : vector <[4 ]x1 xf32 >, memref <24 x1 xf32 >
213
+ return
214
+ }
215
+ // CHECK: func.func @leading_scalable_dimension_transfer_write
216
+ // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
217
+ // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
218
+ // CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
219
+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
220
+ // CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
221
+ // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]][%[[IDX]]] {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
222
+
223
+ // -----
224
+
225
+ // Negative test: [1] (scalable 1) is _not_ a unit dimension.
226
+ func.func @trailing_scalable_one_dim_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <4 x[1 ]xf32 >, %index: index ) {
227
+ %c0 = arith.constant 0 : index
228
+ vector.transfer_write %vec , %dest [%index , %c0 ] {in_bounds = [true , true ]} : vector <4 x[1 ]xf32 >, memref <24 x1 xf32 >
229
+ return
230
+ }
231
+ // CHECK: func.func @trailing_scalable_one_dim_transfer_write
232
+ // CHECK-NOT: vector.shape_cast
233
+ // CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
234
+ // CHECK-NOT: vector.shape_cast
0 commit comments