@@ -61,16 +61,18 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
61
61
// / AFTER:
62
62
// / ```mlir
63
63
// / %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
64
- // / %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
64
+ // / %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
65
65
// / %vscale = vector.vscale
66
66
// / %c0 = arith.constant 0 : index
67
67
// / %c1 = arith.constant 1 : index
68
68
// / %min_svl_s = arith.constant 4 : index
69
69
// / %svl_s = arith.muli %min_svl_s, %vscale : index
70
- // / scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
70
+ // / %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
71
+ // / iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
71
72
// / %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
72
- // / %ptrue_s, %tile , %tile_slice_idx
73
+ // / %ptrue_s, %iter_tile , %tile_slice_idx
73
74
// / : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
75
+ // / scf.yield %tile_update : vector<[4]x[4]xi32>
74
76
// / }
75
77
// / ```
76
78
struct TileLoadOpConversion : public OpRewritePattern <arm_sme::TileLoadOp> {
@@ -88,7 +90,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
88
90
auto tileElementType = tileType.getElementType ();
89
91
90
92
// Allocate a new SME tile.
91
- auto tile = tileLoadOp.createOpAndForwardTileId <arm_sme::GetTileOp>(
93
+ auto initTile = tileLoadOp.createOpAndForwardTileId <arm_sme::GetTileOp>(
92
94
rewriter, loc, tileType);
93
95
94
96
// Create a loop that loads each ZA tile slice from memory.
@@ -103,8 +105,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
103
105
// ..., SVL_Q).
104
106
auto numTileSlices =
105
107
rewriter.create <arith::MulIOp>(loc, minTileSlices, vscale);
106
- auto forOp =
107
- rewriter. create <scf::ForOp>(loc, lowerBound, numTileSlices, step);
108
+ auto forOp = rewriter. create <scf::ForOp>(loc, lowerBound, numTileSlices,
109
+ step, ValueRange{initTile} );
108
110
109
111
rewriter.setInsertionPointToStart (forOp.getBody ());
110
112
@@ -121,14 +123,17 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
121
123
getMemrefIndices (tileLoadOp.getIndices (),
122
124
tileLoadOp.getMemRefType ().getRank (), tileSliceIndex,
123
125
numTileSlices, memrefIndices, loc, rewriter);
124
- tileLoadOp.createOpAndForwardTileId <arm_sme::LoadTileSliceOp>(
125
- rewriter, loc, tileType, tileLoadOp.getBase (), allTruePredicate, tile,
126
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout ());
126
+ auto currentTile = forOp.getRegionIterArg (0 );
127
+ auto loadSlice =
128
+ tileLoadOp.createOpAndForwardTileId <arm_sme::LoadTileSliceOp>(
129
+ rewriter, loc, tileType, tileLoadOp.getBase (), allTruePredicate,
130
+ currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout ());
131
+ rewriter.create <scf::YieldOp>(loc, loadSlice.getResult ());
127
132
128
133
rewriter.setInsertionPointAfter (forOp);
129
134
130
- // Replace 'arm_sme.tile_load' with the tile .
131
- rewriter.replaceOp (tileLoadOp, tile );
135
+ // Replace 'arm_sme.tile_load' with the result .
136
+ rewriter.replaceOp (tileLoadOp, forOp. getResult ( 0 ) );
132
137
133
138
return success ();
134
139
}
@@ -150,13 +155,15 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
150
155
// / ```mlir
151
156
// / %c0 = arith.constant 0 : index
152
157
// / %c1 = arith.constant 1 : index
153
- // / %tile = arm_sme.zero : vector<[4]x[4]xi32>
158
+ // / %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
154
159
// / %num_rows = arith.constant 2 : index
155
160
// / %num_cols = vector.create_mask %c4 : vector<[4]xi1>
156
- // / scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
161
+ // / %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1
162
+ // / iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
157
163
// / %tile_update = arm_sme.load_tile_slice
158
- // / %src[%tile_slice_idx], %num_cols, %tile , %tile_slice_idx :
164
+ // / %src[%tile_slice_idx], %num_cols, %iter_tile , %tile_slice_idx :
159
165
// / memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
166
+ // / scf.yield %tile_update : vector<[4]x[4]xi32>
160
167
// / }
161
168
// / ```
162
169
// /
@@ -202,32 +209,36 @@ struct TileLoadOpWithMaskAndPadZeroConversion
202
209
// Initialize tile with zero to satisfy padding. Inactive cols will be
203
210
// zeroed anyway since the loads use zeroing predication. For inactive rows
204
211
// however, no load will occur so these need to be zeroed.
205
- auto tile = tileLoadOp.createOpAndForwardTileId <arm_sme::ZeroOp>(
212
+ auto initTile = tileLoadOp.createOpAndForwardTileId <arm_sme::ZeroOp>(
206
213
rewriter, loc, tileType);
207
214
208
215
// Create a loop to load the active tile slices from memory.
209
216
auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
210
217
auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
211
218
auto upperBound = numRows;
212
- auto forOp = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step);
219
+ auto forOp = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step,
220
+ ValueRange{initTile});
213
221
214
222
rewriter.setInsertionPointToStart (forOp.getBody ());
215
223
216
224
// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
217
225
// tile.
218
226
SmallVector<Value> memrefIndices;
219
227
auto tileSliceIndex = forOp.getInductionVar ();
228
+ auto currentTile = forOp.getRegionIterArg (0 );
220
229
getMemrefIndices (tileLoadOp.getIndices (),
221
230
tileLoadOp.getMemRefType ().getRank (), tileSliceIndex,
222
231
upperBound, memrefIndices, loc, rewriter);
223
- tileLoadOp.createOpAndForwardTileId <arm_sme::LoadTileSliceOp>(
224
- rewriter, loc, tileType, tileLoadOp.getBase (), numColsOp, tile,
225
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout ());
232
+ auto loadSlice =
233
+ tileLoadOp.createOpAndForwardTileId <arm_sme::LoadTileSliceOp>(
234
+ rewriter, loc, tileType, tileLoadOp.getBase (), numColsOp,
235
+ currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout ());
236
+ rewriter.create <scf::YieldOp>(loc, loadSlice.getResult ());
226
237
227
238
rewriter.setInsertionPointAfter (forOp);
228
239
229
- // Replace 'arm_sme.tile_load' with the tile .
230
- rewriter.replaceOp (tileLoadOp, tile );
240
+ // Replace 'arm_sme.tile_load' with the result .
241
+ rewriter.replaceOp (tileLoadOp, forOp. getResult ( 0 ) );
231
242
232
243
return success ();
233
244
}
@@ -249,15 +260,18 @@ struct TileLoadOpWithMaskAndPadZeroConversion
249
260
// / ```mlir
250
261
// / ...
251
262
// / %pad_1d = arith.constant dense<1> : vector<[4]xi32>
252
- // / scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
263
+ // / %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
264
+ // / iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
253
265
// / ...
254
266
// / %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
255
267
// / %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
256
268
// / : memref<?x?xi32>, vector<[4]xi1>,
257
269
// / vector<[4]xi32> into vector<[4]xi32>
258
270
// / // Insert slice into tile
259
- // / arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
260
- // / : vector<[4]xi32> into vector<[4]x[4]xi32>
271
+ // / %tile_update = arm_sme.move_vector_to_tile_slice
272
+ // / %slice, %iter_tile, %tile_slice_idx :
273
+ // / vector<[4]xi32> into vector<[4]x[4]xi32>
274
+ // / scf.yield %tile_update : vector<[4]x[4]xi32>
261
275
// / }
262
276
// / ```
263
277
struct TileLoadOpWithMaskAndPadNonZeroConversion
@@ -298,7 +312,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
298
312
loc, rewriter.getI32Type (), numCols);
299
313
300
314
// Allocate a new SME tile.
301
- auto tile = tileLoadOp.createOpAndForwardTileId <arm_sme::GetTileOp>(
315
+ auto initTile = tileLoadOp.createOpAndForwardTileId <arm_sme::GetTileOp>(
302
316
rewriter, loc, tileType);
303
317
304
318
// Create a loop that loads each ZA tile slice from memory.
@@ -310,12 +324,13 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
310
324
auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
311
325
auto numTileSlices =
312
326
rewriter.create <arith::MulIOp>(loc, minTileSlices, vscale);
313
- auto forOp =
314
- rewriter. create <scf::ForOp>(loc, lowerBound, numTileSlices, step);
327
+ auto forOp = rewriter. create <scf::ForOp>(loc, lowerBound, numTileSlices,
328
+ step, ValueRange{initTile} );
315
329
316
330
rewriter.setInsertionPointToStart (forOp.getBody ());
317
331
318
332
auto tileSliceIndex = forOp.getInductionVar ();
333
+ auto currentTile = forOp.getRegionIterArg (0 );
319
334
320
335
// Combine masks.
321
336
auto rowIsActive = rewriter.create <arith::CmpIOp>(
@@ -344,14 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
344
359
/* passthru=*/ pad1DOp);
345
360
346
361
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
347
- tileLoadOp.createOpAndForwardTileId <arm_sme::MoveVectorToTileSliceOp>(
348
- rewriter, loc, tileType, loadSlice->getResult (0 ), tile, tileSliceIndex,
349
- tileLoadOp.getLayout ());
362
+ auto moveSlice =
363
+ tileLoadOp.createOpAndForwardTileId <arm_sme::MoveVectorToTileSliceOp>(
364
+ rewriter, loc, tileType, loadSlice->getResult (0 ), currentTile,
365
+ tileSliceIndex, tileLoadOp.getLayout ());
366
+ rewriter.create <scf::YieldOp>(loc, moveSlice.getResult ());
350
367
351
368
rewriter.setInsertionPointAfter (forOp);
352
369
353
- // Replace 'arm_sme.tile_load' with the tile .
354
- rewriter.replaceOp (tileLoadOp, tile );
370
+ // Replace 'arm_sme.tile_load' with the result .
371
+ rewriter.replaceOp (tileLoadOp, forOp. getResult ( 0 ) );
355
372
356
373
return success ();
357
374
}
0 commit comments