@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
174
174
func.func @transfer_write_f16_scalable_16x8 (%dest: memref <?x?xf16 >, %vec: vector <[16 ]x[8 ]xf16 >)
175
175
{
176
176
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
177
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
177
178
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
178
179
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
179
180
// CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
180
- // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
181
- // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
181
+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
182
+ // CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
183
+ // CHECK-NEXT: vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
184
+ // CHECK-NEXT: %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
185
+ // CHECK-NEXT: %[[BOTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
186
+ // CHECK-NEXT: vector.transfer_write %[[BOTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
187
+ // CHECK-NEXT: }
182
188
// CHECK-NEXT: return
183
189
%c0 = arith.constant 0 : index
184
190
vector.transfer_write %vec , %dest [%c0 , %c0 ] {in_bounds = [true , true ]} : vector <[16 ]x[8 ]xf16 >, memref <?x?xf16 >
@@ -201,6 +207,47 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
201
207
202
208
// -----
203
209
210
+ // CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
211
+ // CHECK-SAME: %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
212
+ // CHECK-SAME: %[[DIM_0:[a-z0-9]+]]: index,
213
+ // CHECK-SAME: %[[DIM_1:[a-z0-9]+]]: index,
214
+ // CHECK-SAME: %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
215
+ // CHECK-SAME: %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
216
+ // CHECK-SAME: %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
217
+ // CHECK-SAME: %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
218
+ func.func @transfer_write_f32_scalable_8x8_masked (%dest: memref <?x?xf32 >, %dim0: index , %dim1: index , %vec: vector <[8 ]x[8 ]xf32 >)
219
+ {
220
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
221
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
222
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
223
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
224
+ // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
225
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
226
+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
227
+ // CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
228
+ // CHECK-NEXT: %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
229
+ // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
230
+ // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
231
+ // CHECK-NEXT: %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
232
+ // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
233
+ // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
234
+ // CHECK-NEXT: %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
235
+ // CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
236
+ // CHECK-NEXT: %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
237
+ // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
238
+ // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
239
+ // CHECK-NEXT: %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
240
+ // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
241
+ // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE:.*]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
242
+ // CHECK-NEXT: }
243
+ %c0 = arith.constant 0 : index
244
+ %mask = vector.create_mask %dim0 , %dim1 : vector <[8 ]x[8 ]xi1 >
245
+ vector.transfer_write %vec , %dest [%c0 , %c0 ], %mask {in_bounds = [true , true ]} : vector <[8 ]x[8 ]xf32 >, memref <?x?xf32 >
246
+ return
247
+ }
248
+
249
+ // -----
250
+
204
251
#transpose = affine_map <(d0 , d1 ) -> (d1 , d0 )>
205
252
206
253
// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +256,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
209
256
func.func @transpose_f32_scalable_4x16_via_read (%src: memref <?x?xf32 >, %dest: memref <?x?xf32 >)
210
257
{
211
258
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
259
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
212
260
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
213
261
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
214
262
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +269,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
221
269
// CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
222
270
// CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
223
271
// CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
224
- // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
225
- // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
226
- // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
227
- // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
272
+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
273
+ // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
274
+ // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
275
+ // CHECK-NEXT: %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
276
+ // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
277
+ // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
278
+ // CHECK-NEXT: %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
279
+ // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
280
+ // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
281
+ // CHECK-NEXT: %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
282
+ // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
283
+ // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
284
+ // CHECK-NEXT: }
228
285
// CHECK-NEXT: return
229
286
%c0 = arith.constant 0 : index
230
287
%pad = arith.constant 0.0 : f32
0 commit comments