@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
80
80
LogicalResult matchAndRewrite (arm_sme::TileLoadOp tileLoadOp,
81
81
PatternRewriter &rewriter) const override {
82
82
if (tileLoadOp.getMask ())
83
- // TODO: add masked patterns.
84
83
return rewriter.notifyMatchFailure (
85
- tileLoadOp, " op has mask, needs masked pattern(s) " );
84
+ tileLoadOp, " op has mask, apply masked patterns " );
86
85
87
86
OpBuilder::InsertionGuard g (rewriter);
88
87
auto loc = tileLoadOp.getLoc ();
@@ -142,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
142
141
}
143
142
};
144
143
144
+ // / Lower `arm_sme.tile_load` with mask and pad of constant zero.
145
+ // /
146
+ // / BEFORE:
147
+ // / ```mlir
148
+ // / %pad = arith.constant 0 : i32
149
+ // / %num_rows = arith.constant 2 : index
150
+ // / %num_cols = arith.constant 4 : index
151
+ // / %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
152
+ // / %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
153
+ // / memref<?x?xi32>, vector<[4]x[4]xi32>
154
+ // / ```
155
+ // /
156
+ // / AFTER:
157
+ // / ```mlir
158
+ // / %c0 = arith.constant 0 : index
159
+ // / %c1 = arith.constant 1 : index
160
+ // / %tile = arm_sme.zero : vector<[4]x[4]xi32>
161
+ // / %num_cols = vector.create_mask %c4 : vector<[4]xi1>
162
+ // / scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
163
+ // / %tile_update = arm_sme.load_tile_slice
164
+ // / %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
165
+ // / memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
166
+ // / }
167
+ // / ```
168
+ // /
169
+ // / NOTE: Only mask of 'vector.create_mask' op is currently supported.
170
+ struct TileLoadOpWithMaskAndPadZeroConversion
171
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
172
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
173
+
174
+ LogicalResult matchAndRewrite (arm_sme::TileLoadOp tileLoadOp,
175
+ PatternRewriter &rewriter) const override {
176
+ OpBuilder::InsertionGuard g (rewriter);
177
+ auto loc = tileLoadOp.getLoc ();
178
+ auto tileType = tileLoadOp.getVectorType ();
179
+
180
+ auto maskOp = tileLoadOp.getMask ();
181
+ if (!maskOp)
182
+ return rewriter.notifyMatchFailure (
183
+ tileLoadOp, " op has no mask, needs unmasked pattern" );
184
+
185
+ auto padOp = tileLoadOp.getPadding ();
186
+ assert (padOp && " expected padding when masking!" );
187
+
188
+ auto createMaskOp = maskOp.getDefiningOp <vector::CreateMaskOp>();
189
+ if (!createMaskOp)
190
+ return rewriter.notifyMatchFailure (
191
+ tileLoadOp, " unsupported mask op, only 'vector.create_mask' is "
192
+ " currently supported" );
193
+
194
+ auto constPadOp = padOp.getDefiningOp <arith::ConstantOp>();
195
+ if (!constPadOp || constPadOp.getValue () !=
196
+ rewriter.getZeroAttr (tileType.getElementType ()))
197
+ return rewriter.notifyMatchFailure (
198
+ tileLoadOp, " op has non-zero pad, needs non-zero pad pattern" );
199
+
200
+ auto numRows = createMaskOp.getOperands ()[0 ];
201
+ auto numCols = createMaskOp.getOperands ()[1 ];
202
+
203
+ auto predicateType =
204
+ VectorType::get (tileType.getDimSize (1 ), rewriter.getI1Type (), true );
205
+ auto numColsOp =
206
+ rewriter.create <vector::CreateMaskOp>(loc, predicateType, numCols);
207
+
208
+ // Initialize tile with zero to satisfy padding. Inactive cols will be
209
+ // zeroed anyway since the loads use zeroing predication. For inactive rows
210
+ // however, no load will occur so these need to be zeroed.
211
+ auto tile = rewriter.create <arm_sme::ZeroOp>(loc, tileType);
212
+
213
+ // Create a loop to load the active tile slices from memory.
214
+ auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
215
+ auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
216
+ auto upperBound = numRows;
217
+ auto forOp = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step);
218
+
219
+ rewriter.setInsertionPointToStart (forOp.getBody ());
220
+
221
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
222
+ // tile.
223
+ SmallVector<Value> memrefIndices;
224
+ auto tileSliceIndex = forOp.getInductionVar ();
225
+ getMemrefIndices (tileLoadOp.getIndices (),
226
+ tileLoadOp.getMemRefType ().getRank (), tileSliceIndex,
227
+ upperBound, memrefIndices, loc, rewriter);
228
+ rewriter.create <arm_sme::LoadTileSliceOp>(
229
+ loc, tileType, tileLoadOp.getBase (), numColsOp, tile, memrefIndices,
230
+ tileSliceIndex, tileLoadOp.getLayout ());
231
+
232
+ rewriter.setInsertionPointAfter (forOp);
233
+
234
+ // Replace 'arm_sme.tile_load' with the tile.
235
+ rewriter.replaceOp (tileLoadOp, tile);
236
+
237
+ return success ();
238
+ }
239
+ };
240
+
241
+ // / Lower `arm_sme.tile_load` with mask and non-zero pad.
242
+ // /
243
+ // / BEFORE:
244
+ // / ```mlir
245
+ // / %pad = arith.constant 1 : i32
246
+ // / %num_rows = arith.constant 2 : index
247
+ // / %num_cols = arith.constant 4 : index
248
+ // / %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
249
+ // / %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
250
+ // / memref<?x?xi32>, vector<[4]x[4]xi32>
251
+ // / ```
252
+ // /
253
+ // / AFTER:
254
+ // / ```mlir
255
+ // / %pad_1d = arith.constant dense<1> : vector<[4]xi32>
256
+ // / %num_rows = arith.constant 2 : index
257
+ // / %num_cols = arith.constant 4 : index
258
+ // / %tile_id = arm_sme.get_tile_id : i32
259
+ // / %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
260
+ // / %vscale = vector.vscale
261
+ // / %c0 = arith.constant 0 : index
262
+ // / %c1 = arith.constant 1 : index
263
+ // / %min_svl_s = arith.constant 4 : index
264
+ // / %svl_s = arith.muli %min_svl_s, %vscale : index
265
+ // / scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
266
+ // / %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
267
+ // / %slice = scf.if %row_is_active -> vector<[4]xi32> {
268
+ // / %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
269
+ // / : memref<?x?xi32>, vector<[4]xi1>,
270
+ // / vector<[4]xi32> into vector<[4]xi32>
271
+ // / scf.yield %slice : vector<[4]xi32>
272
+ // / } else {
273
+ // / scf.yield %pad_1d : vector<[4]xi32>
274
+ // / }
275
+ // / // Insert slice into tile
276
+ // / arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
277
+ // / : vector<[4]xi32> into vector<[4]x[4]xi32>
278
+ // / }
279
+ // / ```
280
+ struct TileLoadOpWithMaskAndPadNonZeroConversion
281
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
282
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
283
+
284
+ LogicalResult matchAndRewrite (arm_sme::TileLoadOp tileLoadOp,
285
+ PatternRewriter &rewriter) const override {
286
+ OpBuilder::InsertionGuard g (rewriter);
287
+ auto loc = tileLoadOp.getLoc ();
288
+ auto tileType = tileLoadOp.getVectorType ();
289
+ auto tileElementType = tileType.getElementType ();
290
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth ();
291
+
292
+ auto maskOp = tileLoadOp.getMask ();
293
+ if (!maskOp)
294
+ return rewriter.notifyMatchFailure (
295
+ tileLoadOp, " op has no mask, needs unmasked pattern" );
296
+
297
+ auto padOp = tileLoadOp.getPadding ();
298
+ assert (padOp && " expected padding when masking!" );
299
+
300
+ auto createMaskOp = maskOp.getDefiningOp <vector::CreateMaskOp>();
301
+ if (!createMaskOp)
302
+ return rewriter.notifyMatchFailure (
303
+ tileLoadOp, " unsupported mask op, only 'vector.create_mask' is "
304
+ " currently supported" );
305
+
306
+ auto constPadOp = padOp.getDefiningOp <arith::ConstantOp>();
307
+ if (constPadOp &&
308
+ constPadOp.getValue () == rewriter.getZeroAttr (tileElementType))
309
+ return rewriter.notifyMatchFailure (
310
+ tileLoadOp, " op has constant zero pad, needs zero pad pattern" );
311
+
312
+ auto numRows = createMaskOp.getOperands ()[0 ];
313
+ auto numCols = createMaskOp.getOperands ()[1 ];
314
+
315
+ VectorType tileSliceType = VectorType::Builder (tileType).dropDim (0 );
316
+ auto predicateType =
317
+ VectorType::get (tileType.getDimSize (1 ), rewriter.getI1Type (), true );
318
+ auto numColsOp =
319
+ rewriter.create <vector::CreateMaskOp>(loc, predicateType, numCols);
320
+
321
+ // Create 'arm_sme.get_tile' op.
322
+ auto tileId = rewriter.create <arm_sme::GetTileID>(
323
+ loc, rewriter.getIntegerType (tileElementWidth));
324
+
325
+ // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
326
+ // use as input tile to 'arm_sme.load_tile_slice' ops.
327
+ auto tile =
328
+ rewriter.create <arm_sme::CastTileToVector>(loc, tileType, tileId);
329
+
330
+ // Create a loop that loads each ZA tile slice from memory.
331
+ auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
332
+ auto minTileSlices = rewriter.create <arith::ConstantIndexOp>(
333
+ loc, arm_sme::getSMETileSliceMinNumElts (tileElementType));
334
+ auto vscale =
335
+ rewriter.create <vector::VectorScaleOp>(loc, rewriter.getIndexType ());
336
+ auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
337
+ auto numTileSlices =
338
+ rewriter.create <arith::MulIOp>(loc, minTileSlices, vscale);
339
+ auto forOp =
340
+ rewriter.create <scf::ForOp>(loc, lowerBound, numTileSlices, step);
341
+
342
+ rewriter.setInsertionPointToStart (forOp.getBody ());
343
+
344
+ auto tileSliceIndex = forOp.getInductionVar ();
345
+
346
+ auto rowIsActive = rewriter.create <arith::CmpIOp>(
347
+ loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
348
+
349
+ SmallVector<Value> memrefIndices;
350
+ getMemrefIndices (tileLoadOp.getIndices (),
351
+ tileLoadOp.getMemRefType ().getRank (), tileSliceIndex,
352
+ numTileSlices, memrefIndices, loc, rewriter);
353
+
354
+ // Splat pad into 1-D vector matching type of tile slice.
355
+ auto pad1DOp = rewriter.create <vector::SplatOp>(loc, tileSliceType, padOp);
356
+
357
+ Operation *slice = rewriter.create <scf::IfOp>(
358
+ loc, rowIsActive,
359
+ [&](OpBuilder &b, Location loc) {
360
+ // If the row is active, emit a masked load where the predicate is
361
+ // 'numCols'. Pad is used for inactive elements, taken from
362
+ // passthru.
363
+ auto loadSlice = rewriter.create <vector::MaskedLoadOp>(
364
+ loc, tileSliceType, tileLoadOp.getBase (), memrefIndices,
365
+ numColsOp, /* passthru=*/ pad1DOp);
366
+ rewriter.create <scf::YieldOp>(loc, loadSlice->getResult (0 ));
367
+ },
368
+ [&](OpBuilder &b, Location loc) {
369
+ // Inactive rows are filled with pad.
370
+ rewriter.create <scf::YieldOp>(loc, pad1DOp.getResult ());
371
+ });
372
+
373
+ // TODO: If the load is vertical the transpose can't be done in-flight with
374
+ // a regular (SVE) maskedload. Propagate layout to
375
+ // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
376
+ // is currently broken.
377
+
378
+ // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
379
+ rewriter.create <arm_sme::MoveVectorToTileSliceOp>(
380
+ loc, tileType, slice->getResult (0 ), tile, tileSliceIndex,
381
+ tileLoadOp.getLayout ());
382
+
383
+ rewriter.setInsertionPointAfter (forOp);
384
+
385
+ // Replace 'arm_sme.tile_load' with the tile.
386
+ rewriter.replaceOp (tileLoadOp, tile);
387
+
388
+ return success ();
389
+ }
390
+ };
391
+
145
392
// / Lower `arm_sme.tile_store` to a loop over the tile slices and store each
146
393
// / slice using `arm_sme.store_tile_slice`.
147
394
// /
@@ -294,7 +541,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
294
541
} // namespace
295
542
296
543
/// Adds the ArmSME-to-SCF conversion patterns to `patterns`: the unmasked
/// tile load, the two masked tile loads (constant-zero pad and non-zero pad),
/// the tile store and the tile vector print conversions. All of them are
/// constructed with the context owned by `patterns`.
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
               TileLoadOpWithMaskAndPadNonZeroConversion,
               TileStoreOpConversion, TileVectorPrintOpConversion>(context);
}
300
548
0 commit comments