Skip to content

Commit b0b69fd

Browse files
authored
[mlir][ArmSME] More precisely model dataflow in ArmSME to SCF lowerings (#73922)
Since #73253, loops over tiles in SSA form (i.e. loops that take `iter_args` and yield a new tile) are supported, so this patch updates ArmSME lowerings to this form. This is a NFC, as it still lowers to the same intrinsics, but this makes IR less 'surprising' at a higher-level, and may be recognised by more transforms. Example: IR before: ```mlir scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 { arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> } // ... later use %tile ``` IR now: ```mlir %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) { %tile_update = arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %iter_tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> scf.yield %tile_update : vector<[4]x[4]xi32> } // ... later use %broadcast_to_tile ```
1 parent c4a77bf commit b0b69fd

File tree

5 files changed

+135
-103
lines changed

5 files changed

+135
-103
lines changed

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,18 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
6161
/// AFTER:
6262
/// ```mlir
6363
/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
64-
/// %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
64+
/// %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
6565
/// %vscale = vector.vscale
6666
/// %c0 = arith.constant 0 : index
6767
/// %c1 = arith.constant 1 : index
6868
/// %min_svl_s = arith.constant 4 : index
6969
/// %svl_s = arith.muli %min_svl_s, %vscale : index
70-
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
70+
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
71+
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
7172
/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
72-
/// %ptrue_s, %tile, %tile_slice_idx
73+
/// %ptrue_s, %iter_tile, %tile_slice_idx
7374
/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
75+
/// scf.yield %tile_update : vector<[4]x[4]xi32>
7476
/// }
7577
/// ```
7678
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
@@ -88,7 +90,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
8890
auto tileElementType = tileType.getElementType();
8991

9092
// Allocate a new SME tile.
91-
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
93+
auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
9294
rewriter, loc, tileType);
9395

9496
// Create a loop that loads each ZA tile slice from memory.
@@ -103,8 +105,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
103105
// ..., SVL_Q).
104106
auto numTileSlices =
105107
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
106-
auto forOp =
107-
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
108+
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
109+
step, ValueRange{initTile});
108110

109111
rewriter.setInsertionPointToStart(forOp.getBody());
110112

@@ -121,14 +123,17 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
121123
getMemrefIndices(tileLoadOp.getIndices(),
122124
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
123125
numTileSlices, memrefIndices, loc, rewriter);
124-
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
125-
rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
126-
memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
126+
auto currentTile = forOp.getRegionIterArg(0);
127+
auto loadSlice =
128+
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
129+
rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
130+
currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
131+
rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
127132

128133
rewriter.setInsertionPointAfter(forOp);
129134

130-
// Replace 'arm_sme.tile_load' with the tile.
131-
rewriter.replaceOp(tileLoadOp, tile);
135+
// Replace 'arm_sme.tile_load' with the result.
136+
rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
132137

133138
return success();
134139
}
@@ -150,13 +155,15 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
150155
/// ```mlir
151156
/// %c0 = arith.constant 0 : index
152157
/// %c1 = arith.constant 1 : index
153-
/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
158+
/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
154159
/// %num_rows = arith.constant 2 : index
155160
/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
156-
/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
161+
/// %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1
162+
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
157163
/// %tile_update = arm_sme.load_tile_slice
158-
/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
164+
/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
159165
/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
166+
/// scf.yield %tile_update : vector<[4]x[4]xi32>
160167
/// }
161168
/// ```
162169
///
@@ -202,32 +209,36 @@ struct TileLoadOpWithMaskAndPadZeroConversion
202209
// Initialize tile with zero to satisfy padding. Inactive cols will be
203210
// zeroed anyway since the loads use zeroing predication. For inactive rows
204211
// however, no load will occur so these need to be zeroed.
205-
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
212+
auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
206213
rewriter, loc, tileType);
207214

208215
// Create a loop to load the active tile slices from memory.
209216
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
210217
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
211218
auto upperBound = numRows;
212-
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
219+
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
220+
ValueRange{initTile});
213221

214222
rewriter.setInsertionPointToStart(forOp.getBody());
215223

216224
// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
217225
// tile.
218226
SmallVector<Value> memrefIndices;
219227
auto tileSliceIndex = forOp.getInductionVar();
228+
auto currentTile = forOp.getRegionIterArg(0);
220229
getMemrefIndices(tileLoadOp.getIndices(),
221230
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
222231
upperBound, memrefIndices, loc, rewriter);
223-
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
224-
rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
225-
memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
232+
auto loadSlice =
233+
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
234+
rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
235+
currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
236+
rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
226237

227238
rewriter.setInsertionPointAfter(forOp);
228239

229-
// Replace 'arm_sme.tile_load' with the tile.
230-
rewriter.replaceOp(tileLoadOp, tile);
240+
// Replace 'arm_sme.tile_load' with the result.
241+
rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
231242

232243
return success();
233244
}
@@ -249,15 +260,18 @@ struct TileLoadOpWithMaskAndPadZeroConversion
249260
/// ```mlir
250261
/// ...
251262
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
252-
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
263+
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
264+
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
253265
/// ...
254266
/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
255267
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
256268
/// : memref<?x?xi32>, vector<[4]xi1>,
257269
/// vector<[4]xi32> into vector<[4]xi32>
258270
/// // Insert slice into tile
259-
/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
260-
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
271+
/// %tile_update = arm_sme.move_vector_to_tile_slice
272+
/// %slice, %iter_tile, %tile_slice_idx :
273+
/// vector<[4]xi32> into vector<[4]x[4]xi32>
274+
/// scf.yield %tile_update : vector<[4]x[4]xi32>
261275
/// }
262276
/// ```
263277
struct TileLoadOpWithMaskAndPadNonZeroConversion
@@ -298,7 +312,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
298312
loc, rewriter.getI32Type(), numCols);
299313

300314
// Allocate a new SME tile.
301-
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
315+
auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
302316
rewriter, loc, tileType);
303317

304318
// Create a loop that loads each ZA tile slice from memory.
@@ -310,12 +324,13 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
310324
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
311325
auto numTileSlices =
312326
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
313-
auto forOp =
314-
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
327+
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
328+
step, ValueRange{initTile});
315329

316330
rewriter.setInsertionPointToStart(forOp.getBody());
317331

318332
auto tileSliceIndex = forOp.getInductionVar();
333+
auto currentTile = forOp.getRegionIterArg(0);
319334

320335
// Combine masks.
321336
auto rowIsActive = rewriter.create<arith::CmpIOp>(
@@ -344,14 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
344359
/*passthru=*/pad1DOp);
345360

346361
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
347-
tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
348-
rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
349-
tileLoadOp.getLayout());
362+
auto moveSlice =
363+
tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
364+
rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
365+
tileSliceIndex, tileLoadOp.getLayout());
366+
rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
350367

351368
rewriter.setInsertionPointAfter(forOp);
352369

353-
// Replace 'arm_sme.tile_load' with the tile.
354-
rewriter.replaceOp(tileLoadOp, tile);
370+
// Replace 'arm_sme.tile_load' with the result.
371+
rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
355372

356373
return success();
357374
}

0 commit comments

Comments
 (0)