@@ -376,27 +376,36 @@ struct LegalizeTransferWriteOpsByDecomposition
376
376
377
377
// / Legalize a multi-tile transfer_write as a single store loop. This is done as
378
378
// / part of type decomposition as at this level we know each tile write is
379
- // / disjoint, but that information is lost after decomposition (without
380
- // / static analysis).
379
+ // / disjoint, but that information is lost after decomposition (without analysis
380
+ // / to reconstruct it).
381
381
// /
382
- // / Example (in pseudo-MLIR):
382
+ // / Example:
383
383
// /
384
384
// / ```
385
- // / vector.transfer_write vector, dest[x, y], mask
386
- // / : vector<[16]x[4]xf32>, memref<?x?xf32>
385
+ // / vector.transfer_write %vector, %dest[%y, %x], %mask
386
+ // / : vector<[16]x[8]xi16>, memref<?x?xi16>
387
387
// / ```
388
388
// / Is rewritten to:
389
389
// / ```
390
- // / for i in range (0, 4 * vscale) {
391
- // / let sliceRow = i + tile_n.row * vscale; ─┐
392
- // / let sliceCol = tile_n.col * vscale; |
393
- // / slice = vector.extract tile_n[i] |
394
- // / : vector<[4]xf32> from vector<[16]x[4]xf32> |
395
- // / slice_mask = vector.extract mask[sliceRow] |- Repeated 4x for
396
- // / : vector<[4]xi1> from vector<[16]x[4]xi1> | all tiles in
397
- // / vector.transfer_write | [16]x[4]
398
- // / slice, dest[x + sliceRow, y + sliceCol], slice_mask |
399
- // / : vector<[4]xf32>, memref<?x?xf32> ┘
390
+ // / scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
391
+ // / %upper_slice_y = arith.addi %slice_idx, %y : index
392
+ // / %upper_slice_mask = vector.extract %mask[%slice_idx]
393
+ // / : vector<[8]xi1> from vector<[16]x[8]xi1>
394
+ // / %upper_slice = vector.extract %upper_tile[%slice_idx]
395
+ // / : vector<[8]xi16> from vector<[8]x[8]xi16>
396
+ // / vector.transfer_write %upper_slice,
397
+ // / %dest[%upper_slice_y, %x], %upper_slice_mask
398
+ // / : vector<[8]xi16>, memref<?x?xi16>
399
+ // / // Same again for the lower tile:
400
+ // / %lower_slice_idx = arith.addi %c8_vscale, %slice_idx : index
401
+ // / %lower_slice_y = arith.addi %lower_slice_idx, %y : index
402
+ // / %lower_slice_mask = vector.extract %mask[%lower_slice_idx]
403
+ // / : vector<[8]xi1> from vector<[16]x[8]xi1>
404
+ // / %lower_slice = vector.extract %lower_tile[%slice_idx]
405
+ // / : vector<[8]xi16> from vector<[8]x[8]xi16>
406
+ // / vector.transfer_write %lower_slice,
407
+ // / %dest[%lower_slice_y, %x], %lower_slice_mask
408
+ // / : vector<[8]xi16>, memref<?x?xi16>
400
409
// / }
401
410
// / ```
402
411
struct LegalizeMultiTileTransferWriteAsStoreLoop
0 commit comments