Skip to content

Commit a887196

Browse files
committed
Fixups
1 parent c07a4a2 commit a887196

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -376,27 +376,36 @@ struct LegalizeTransferWriteOpsByDecomposition
376376

377377
/// Legalize a multi-tile transfer_write as a single store loop. This is done as
378378
/// part of type decomposition as at this level we know each tile write is
379-
/// disjoint, but that information is lost after decomposition (without
380-
/// static analysis).
379+
/// disjoint, but that information is lost after decomposition (without analysis
380+
/// to reconstruct it).
381381
///
382-
/// Example (in pseudo-MLIR):
382+
/// Example:
383383
///
384384
/// ```
385-
/// vector.transfer_write vector, dest[x, y], mask
386-
/// : vector<[16]x[4]xf32>, memref<?x?xf32>
385+
/// vector.transfer_write %vector, %dest[%y, %x], %mask
386+
/// : vector<[16]x[8]xi16>, memref<?x?xi16>
387387
/// ```
388388
/// Is rewritten to:
389389
/// ```
390-
/// for i in range (0, 4 * vscale) {
391-
/// let sliceRow = i + tile_n.row * vscale; ─┐
392-
/// let sliceCol = tile_n.col * vscale; |
393-
/// slice = vector.extract tile_n[i] |
394-
/// : vector<[4]xf32> from vector<[16]x[4]xf32> |
395-
/// slice_mask = vector.extract mask[sliceRow] |- Repeated 4x for
396-
/// : vector<[4]xi1> from vector<[16]x[4]xi1> | all tiles in
397-
/// vector.transfer_write | [16]x[4]
398-
/// slice, dest[x + sliceRow, y + sliceCol], slice_mask |
399-
/// : vector<[4]xf32>, memref<?x?xf32> ┘
390+
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
391+
/// %upper_slice_y = arith.addi %slice_idx, %y : index
392+
/// %upper_slice_mask = vector.extract %mask[%slice_idx]
393+
/// : vector<[8]xi1> from vector<[16]x[8]xi1>
394+
/// %upper_slice = vector.extract %upper_tile[%slice_idx]
395+
/// : vector<[8]xi16> from vector<[8]x[8]xi16>
396+
/// vector.transfer_write %upper_slice,
397+
/// %dest[%upper_slice_y, %x], %upper_slice_mask
398+
/// : vector<[8]xi16>, memref<?x?xi16>
399+
/// // Same again for the lower tile:
400+
/// %lower_slice_idx = arith.addi %c8_vscale, %slice_idx : index
401+
/// %lower_slice_y = arith.addi %lower_slice_idx, %y : index
402+
/// %lower_slice_mask = vector.extract %mask[%lower_slice_idx]
403+
/// : vector<[8]xi1> from vector<[16]x[8]xi1>
404+
/// %lower_slice = vector.extract %lower_tile[%slice_idx]
405+
/// : vector<[8]xi16> from vector<[8]x[8]xi16>
406+
/// vector.transfer_write %lower_slice,
407+
/// %dest[%lower_slice_y, %x], %lower_slice_mask
408+
/// : vector<[8]xi16>, memref<?x?xi16>
400409
/// }
401410
/// ```
402411
struct LegalizeMultiTileTransferWriteAsStoreLoop

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,35 @@ func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0:
248248

249249
// -----
250250

251+
// Tensor semantics are not supported for the store loop lowering.
252+
253+
// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
254+
// CHECK-NOT: scf.for
255+
func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
256+
{
257+
%c0 = arith.constant 0 : index
258+
vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
259+
return
260+
}
261+
262+
// -----
263+
264+
#transpose = affine_map<(d0, d1) -> (d1, d0)>
265+
266+
// Transposes are not supported for the store loop lowering.
267+
268+
// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
269+
// CHECK-NOT: scf.for
270+
func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
271+
{
272+
%c0 = arith.constant 0 : index
273+
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
274+
vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
275+
return
276+
}
277+
278+
// -----
279+
251280
#transpose = affine_map<(d0, d1) -> (d1, d0)>
252281

253282
// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(

0 commit comments

Comments
 (0)