@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
80
80
LogicalResult matchAndRewrite (arm_sme::TileLoadOp tileLoadOp,
81
81
PatternRewriter &rewriter) const override {
82
82
if (tileLoadOp.getMask ())
83
- // TODO: add masked patterns.
84
- return rewriter.notifyMatchFailure (
85
- tileLoadOp, " op has mask, needs masked pattern(s)" );
83
+ return rewriter.notifyMatchFailure (tileLoadOp,
84
+ " op has mask, apply masked patterns" );
86
85
87
86
OpBuilder::InsertionGuard g (rewriter);
88
87
auto loc = tileLoadOp.getLoc ();
@@ -142,6 +141,234 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
142
141
}
143
142
};
144
143
144
+ // / Lower `arm_sme.tile_load` with mask and pad of constant zero.
145
+ // /
146
+ // / BEFORE:
147
+ // / ```mlir
148
+ // / %pad = arith.constant 0 : i32
149
+ // / %num_rows = arith.constant 2 : index
150
+ // / %num_cols = arith.constant 4 : index
151
+ // / %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
152
+ // / %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
153
+ // / memref<?x?xi32>, vector<[4]x[4]xi32>
154
+ // / ```
155
+ // /
156
+ // / AFTER:
157
+ // / ```mlir
158
+ // / %c0 = arith.constant 0 : index
159
+ // / %c1 = arith.constant 1 : index
160
+ // / %tile = arm_sme.zero : vector<[4]x[4]xi32>
161
+ // / %num_rows = arith.constant 2 : index
162
+ // / %num_cols = vector.create_mask %c4 : vector<[4]xi1>
163
+ // / scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
164
+ // / %tile_update = arm_sme.load_tile_slice
165
+ // / %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
166
+ // / memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
167
+ // / }
168
+ // / ```
169
+ // /
170
+ // / NOTE: Only mask of 'vector.create_mask' op is currently supported.
171
+ struct TileLoadOpWithMaskAndPadZeroConversion
172
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
173
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
174
+
175
+ LogicalResult matchAndRewrite (arm_sme::TileLoadOp tileLoadOp,
176
+ PatternRewriter &rewriter) const override {
177
+ OpBuilder::InsertionGuard g (rewriter);
178
+ auto loc = tileLoadOp.getLoc ();
179
+ auto tileType = tileLoadOp.getVectorType ();
180
+
181
+ auto maskOp = tileLoadOp.getMask ();
182
+ if (!maskOp)
183
+ return rewriter.notifyMatchFailure (
184
+ tileLoadOp, " op has no mask, needs unmasked pattern" );
185
+
186
+ auto padOp = tileLoadOp.getPadding ();
187
+ assert (padOp && " expected padding when masking!" );
188
+
189
+ auto createMaskOp = maskOp.getDefiningOp <vector::CreateMaskOp>();
190
+ if (!createMaskOp)
191
+ return rewriter.notifyMatchFailure (
192
+ tileLoadOp, " unsupported mask op, only 'vector.create_mask' is "
193
+ " currently supported" );
194
+
195
+ auto constPadOp = padOp.getDefiningOp <arith::ConstantOp>();
196
+ if (!constPadOp || constPadOp.getValue () !=
197
+ rewriter.getZeroAttr (tileType.getElementType ()))
198
+ return rewriter.notifyMatchFailure (
199
+ tileLoadOp, " op has non-zero pad, needs non-zero pad pattern" );
200
+
201
+ auto numRows = createMaskOp.getOperands ()[0 ];
202
+ auto numCols = createMaskOp.getOperands ()[1 ];
203
+
204
+ auto predicateType =
205
+ VectorType::get (tileType.getDimSize (1 ), rewriter.getI1Type (), true );
206
+ auto numColsOp =
207
+ rewriter.create <vector::CreateMaskOp>(loc, predicateType, numCols);
208
+
209
+ // Initialize tile with zero to satisfy padding. Inactive cols will be
210
+ // zeroed anyway since the loads use zeroing predication. For inactive rows
211
+ // however, no load will occur so these need to be zeroed.
212
+ auto tile = rewriter.create <arm_sme::ZeroOp>(loc, tileType);
213
+
214
+ // Create a loop to load the active tile slices from memory.
215
+ auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
216
+ auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
217
+ auto upperBound = numRows;
218
+ auto forOp = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step);
219
+
220
+ rewriter.setInsertionPointToStart (forOp.getBody ());
221
+
222
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
223
+ // tile.
224
+ SmallVector<Value> memrefIndices;
225
+ auto tileSliceIndex = forOp.getInductionVar ();
226
+ getMemrefIndices (tileLoadOp.getIndices (),
227
+ tileLoadOp.getMemRefType ().getRank (), tileSliceIndex,
228
+ upperBound, memrefIndices, loc, rewriter);
229
+ rewriter.create <arm_sme::LoadTileSliceOp>(
230
+ loc, tileType, tileLoadOp.getBase (), numColsOp, tile, memrefIndices,
231
+ tileSliceIndex, tileLoadOp.getLayout ());
232
+
233
+ rewriter.setInsertionPointAfter (forOp);
234
+
235
+ // Replace 'arm_sme.tile_load' with the tile.
236
+ rewriter.replaceOp (tileLoadOp, tile);
237
+
238
+ return success ();
239
+ }
240
+ };
241
+
242
+ // / Lower `arm_sme.tile_load` with mask and non-zero pad.
243
+ // /
244
+ // / BEFORE:
245
+ // / ```mlir
246
+ // / %pad = arith.constant 1 : i32
247
+ // / %num_rows = arith.constant 2 : index
248
+ // / %num_cols = arith.constant 4 : index
249
+ // / %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
250
+ // / %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
251
+ // / memref<?x?xi32>, vector<[4]x[4]xi32>
252
+ // / ```
253
+ // /
254
+ // / AFTER:
255
+ // / ```mlir
256
+ // / ...
257
+ // / %pad_1d = arith.constant dense<1> : vector<[4]xi32>
258
+ // / scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
259
+ // / ...
260
+ // / %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
261
+ // / %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
262
+ // / : memref<?x?xi32>, vector<[4]xi1>,
263
+ // / vector<[4]xi32> into vector<[4]xi32>
264
+ // / // Insert slice into tile
265
+ // / arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
266
+ // / : vector<[4]xi32> into vector<[4]x[4]xi32>
267
+ // / }
268
+ // / ```
269
+ struct TileLoadOpWithMaskAndPadNonZeroConversion
270
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
271
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
272
+
273
+ LogicalResult matchAndRewrite (arm_sme::TileLoadOp tileLoadOp,
274
+ PatternRewriter &rewriter) const override {
275
+ OpBuilder::InsertionGuard g (rewriter);
276
+ auto loc = tileLoadOp.getLoc ();
277
+ auto tileType = tileLoadOp.getVectorType ();
278
+ auto tileElementType = tileType.getElementType ();
279
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth ();
280
+
281
+ auto maskOp = tileLoadOp.getMask ();
282
+ if (!maskOp)
283
+ return rewriter.notifyMatchFailure (
284
+ tileLoadOp, " op has no mask, needs unmasked pattern" );
285
+
286
+ auto padOp = tileLoadOp.getPadding ();
287
+ assert (padOp && " expected padding when masking!" );
288
+
289
+ auto createMaskOp = maskOp.getDefiningOp <vector::CreateMaskOp>();
290
+ if (!createMaskOp)
291
+ return rewriter.notifyMatchFailure (
292
+ tileLoadOp, " unsupported mask op, only 'vector.create_mask' is "
293
+ " currently supported" );
294
+
295
+ auto constPadOp = padOp.getDefiningOp <arith::ConstantOp>();
296
+ if (constPadOp &&
297
+ constPadOp.getValue () == rewriter.getZeroAttr (tileElementType))
298
+ return rewriter.notifyMatchFailure (
299
+ tileLoadOp, " op has constant zero pad, needs zero pad pattern" );
300
+
301
+ auto numRows = createMaskOp.getOperands ()[0 ];
302
+ auto numCols = createMaskOp.getOperands ()[1 ];
303
+
304
+ auto numColsI32 = rewriter.create <arith::IndexCastUIOp>(
305
+ loc, rewriter.getI32Type (), numCols);
306
+
307
+ // Create 'arm_sme.get_tile' op.
308
+ auto tileId = rewriter.create <arm_sme::GetTileID>(
309
+ loc, rewriter.getIntegerType (tileElementWidth));
310
+
311
+ // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
312
+ // use as input tile to 'arm_sme.load_tile_slice' ops.
313
+ auto tile =
314
+ rewriter.create <arm_sme::CastTileToVector>(loc, tileType, tileId);
315
+
316
+ // Create a loop that loads each ZA tile slice from memory.
317
+ auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
318
+ auto minTileSlices = rewriter.create <arith::ConstantIndexOp>(
319
+ loc, arm_sme::getSMETileSliceMinNumElts (tileElementType));
320
+ auto vscale =
321
+ rewriter.create <vector::VectorScaleOp>(loc, rewriter.getIndexType ());
322
+ auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
323
+ auto numTileSlices =
324
+ rewriter.create <arith::MulIOp>(loc, minTileSlices, vscale);
325
+ auto forOp =
326
+ rewriter.create <scf::ForOp>(loc, lowerBound, numTileSlices, step);
327
+
328
+ rewriter.setInsertionPointToStart (forOp.getBody ());
329
+
330
+ auto tileSliceIndex = forOp.getInductionVar ();
331
+
332
+ // Combine masks.
333
+ auto rowIsActive = rewriter.create <arith::CmpIOp>(
334
+ loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
335
+ auto rowIsActiveI32 = rewriter.create <arith::ExtSIOp>(
336
+ loc, rewriter.getI32Type (), rowIsActive);
337
+ auto mask = rewriter.create <arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
338
+ auto maskIndex =
339
+ rewriter.create <arith::IndexCastOp>(loc, rewriter.getIndexType (), mask);
340
+ auto predicateType =
341
+ VectorType::get (tileType.getDimSize (1 ), rewriter.getI1Type (), true );
342
+ auto maskOp1D = rewriter.create <vector::CreateMaskOp>(
343
+ loc, predicateType, maskIndex.getResult ());
344
+
345
+ SmallVector<Value> memrefIndices;
346
+ getMemrefIndices (tileLoadOp.getIndices (),
347
+ tileLoadOp.getMemRefType ().getRank (), tileSliceIndex,
348
+ numTileSlices, memrefIndices, loc, rewriter);
349
+
350
+ // Splat pad into 1-D vector matching type of tile slice.
351
+ VectorType tileSliceType = VectorType::Builder (tileType).dropDim (0 );
352
+ auto pad1DOp = rewriter.create <vector::SplatOp>(loc, tileSliceType, padOp);
353
+
354
+ auto loadSlice = rewriter.create <vector::MaskedLoadOp>(
355
+ loc, tileSliceType, tileLoadOp.getBase (), memrefIndices, maskOp1D,
356
+ /* passthru=*/ pad1DOp);
357
+
358
+ // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
359
+ rewriter.create <arm_sme::MoveVectorToTileSliceOp>(
360
+ loc, tileType, loadSlice->getResult (0 ), tile, tileSliceIndex,
361
+ tileLoadOp.getLayout ());
362
+
363
+ rewriter.setInsertionPointAfter (forOp);
364
+
365
+ // Replace 'arm_sme.tile_load' with the tile.
366
+ rewriter.replaceOp (tileLoadOp, tile);
367
+
368
+ return success ();
369
+ }
370
+ };
371
+
145
372
// / Lower `arm_sme.tile_store` to a loop over the tile slices and store each
146
373
// / slice using `arm_sme.store_tile_slice`.
147
374
// /
@@ -294,7 +521,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
294
521
} // namespace
295
522
296
523
void mlir::populateArmSMEToSCFConversionPatterns (RewritePatternSet &patterns) {
297
- patterns.add <TileLoadOpConversion, TileStoreOpConversion,
524
+ patterns.add <TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
525
+ TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
298
526
TileVectorPrintOpConversion>(patterns.getContext ());
299
527
}
300
528
0 commit comments