@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
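+ // The store legalization now emits an scf.for store loop: each iteration
+ // extracts one vector<[8]xf16> row slice from the top and bottom halves and
+ // writes it out, replacing the two whole-tile 2-D stores checked previously.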
func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
{
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
- // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
- // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
+ // CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+ // CHECK-NEXT: vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+ // CHECK-NEXT: %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+ // CHECK-NEXT: %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+ // CHECK-NEXT: vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+ // CHECK-NEXT: }
// CHECK-NEXT: return
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
@@ -201,6 +207,90 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
// -----
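+ // A vector<[8]x[8]xf32> splits into a 2x2 grid of [4]x[4]xf32 SME tiles, which
+ // surface as the four TILE_0..TILE_3 arguments below. Each loop iteration
+ // stores one row of the two upper tiles and one row of the two lower tiles,
+ // masked with per-row slices of the original [8]x[8] mask.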
+ // CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
+ // CHECK-SAME: %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
+ // CHECK-SAME: %[[DIM_0:[a-z0-9]+]]: index,
+ // CHECK-SAME: %[[DIM_1:[a-z0-9]+]]: index,
+ // CHECK-SAME: %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+ // CHECK-SAME: %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+ // CHECK-SAME: %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+ // CHECK-SAME: %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
+ func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+ {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+ // CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+ // CHECK-NEXT: %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+ // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+ // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+ // CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+ // CHECK-NEXT: %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+ // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+ // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: }
+   %c0 = arith.constant 0 : index
+   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+   vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   return
+ }
+
+ // -----
+
+ // Tensor semantics are not supported for the store loop lowering.
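+ // (With tensor semantics, vector.transfer_write returns an updated tensor, so
+ // a per-slice store loop would have to thread that result through iter_args.)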
+
+ // CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
+ // CHECK-NOT: scf.for
+ func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
+ {
+   %c0 = arith.constant 0 : index
+   %result = vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+   return
+ }
+
+ // -----
+
+ #transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+ // Transposes are not supported for the store loop lowering.
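+ // (Likely because with the transposed permutation_map an in-register row maps
+ // to a column in memory, so a per-row store loop would not produce contiguous
+ // row writes.)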
+
+ // CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_transpose
+ // CHECK-NOT: scf.for
+ func.func @negative_transfer_write_f32_scalable_8x8_transpose(%dest: tensor<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+ {
+   %c0 = arith.constant 0 : index
+   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+   %result = vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+   return
+ }
+
+ // -----
+
+ // Masked writes where any dimension of the mask is > 16 are not supported for the store loop lowering.
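+ // (A vector<[32]x[32]xf32> decomposes into an 8x8 grid of [4]x[4] tiles, i.e.
+ // 64 tiles; presumably the > 16 cutoff keeps the generated loop body small.)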
+
+ // CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
+ // CHECK-NOT: scf.for
+ func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
+ {
+   %c0 = arith.constant 0 : index
+   %mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
+   vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
+   return
+ }
+
+ // -----
+
#transpose = affine_map<(d0, d1) -> (d1, d0)>

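+ // The transpose is folded into the four transfer_reads (note their
+ // permutation_map), so the store loop below writes plain row slices.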
// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +299,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
{
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +312,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
// CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
- // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
- // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
- // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
- // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+ // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+ // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+ // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
+ // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: }
// CHECK-NEXT: return
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32