@@ -255,6 +255,7 @@ struct TileLoadOpWithMaskAndPadZeroConversion
255
255
// / %pad_1d = arith.constant dense<1> : vector<[4]xi32>
256
256
// / %num_rows = arith.constant 2 : index
257
257
// / %num_cols = arith.constant 4 : index
258
+ // / %num_cols_i32 = arith.index_castui %num_cols : index to i32
258
259
// / %tile_id = arm_sme.get_tile_id : i32
259
260
// / %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
260
261
// / %vscale = vector.vscale
@@ -264,14 +265,13 @@ struct TileLoadOpWithMaskAndPadZeroConversion
264
265
// / %svl_s = arith.muli %min_svl_s, %vscale : index
265
266
// / scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
266
267
// / %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
267
- // / %slice = scf.if %row_is_active -> vector<[4]xi32> {
268
- // / %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
269
- // / : memref<?x?xi32>, vector<[4]xi1>,
270
- // / vector<[4]xi32> into vector<[4]xi32>
271
- // / scf.yield %slice : vector<[4]xi32>
272
- // / } else {
273
- // / scf.yield %pad_1d : vector<[4]xi32>
274
- // / }
268
+ // / %row_is_active_i32 = arith.extsi %row_is_active : i1 to i32
269
+ // / %mask = arith.andi %row_is_active_i32, %num_cols_i32 : i32
270
+ // / %mask_index = arith.index_cast %mask : i32 to index
271
+ // / %mask_1d = vector.create_mask %mask_index : vector<[4]xi1>
272
+ // / %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad
273
+ // / : memref<?x?xi32>, vector<[4]xi1>,
274
+ // / vector<[4]xi32> into vector<[4]xi32>
275
275
// / // Insert slice into tile
276
276
// / arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
277
277
// / : vector<[4]xi32> into vector<[4]x[4]xi32>
@@ -312,11 +312,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
312
312
auto numRows = createMaskOp.getOperands ()[0 ];
313
313
auto numCols = createMaskOp.getOperands ()[1 ];
314
314
315
- VectorType tileSliceType = VectorType::Builder (tileType).dropDim (0 );
316
- auto predicateType =
317
- VectorType::get (tileType.getDimSize (1 ), rewriter.getI1Type (), true );
318
- auto numColsOp =
319
- rewriter.create <vector::CreateMaskOp>(loc, predicateType, numCols);
315
+ auto numColsI32 = rewriter.create <arith::IndexCastUIOp>(
316
+ loc, rewriter.getI32Type (), numCols);
320
317
321
318
// Create 'arm_sme.get_tile' op.
322
319
auto tileId = rewriter.create <arm_sme::GetTileID>(
@@ -343,41 +340,35 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
343
340
344
341
auto tileSliceIndex = forOp.getInductionVar ();
345
342
343
+ // Combine masks.
346
344
auto rowIsActive = rewriter.create <arith::CmpIOp>(
347
345
loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
346
+ auto rowIsActiveI32 = rewriter.create <arith::ExtSIOp>(
347
+ loc, rewriter.getI32Type (), rowIsActive);
348
+ auto mask = rewriter.create <arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
349
+ auto maskIndex =
350
+ rewriter.create <arith::IndexCastOp>(loc, rewriter.getIndexType (), mask);
351
+ auto predicateType =
352
+ VectorType::get (tileType.getDimSize (1 ), rewriter.getI1Type (), true );
353
+ auto maskOp1D = rewriter.create <vector::CreateMaskOp>(
354
+ loc, predicateType, maskIndex.getResult ());
348
355
349
356
SmallVector<Value> memrefIndices;
350
357
getMemrefIndices (tileLoadOp.getIndices (),
351
358
tileLoadOp.getMemRefType ().getRank (), tileSliceIndex,
352
359
numTileSlices, memrefIndices, loc, rewriter);
353
360
354
361
// Splat pad into 1-D vector matching type of tile slice.
362
+ VectorType tileSliceType = VectorType::Builder (tileType).dropDim (0 );
355
363
auto pad1DOp = rewriter.create <vector::SplatOp>(loc, tileSliceType, padOp);
356
364
357
- Operation *slice = rewriter.create <scf::IfOp>(
358
- loc, rowIsActive,
359
- [&](OpBuilder &b, Location loc) {
360
- // If the row is active, emit a masked load where the predicate is
361
- // 'numCols'. Pad is used for inactive elements, taken from
362
- // passthru.
363
- auto loadSlice = rewriter.create <vector::MaskedLoadOp>(
364
- loc, tileSliceType, tileLoadOp.getBase (), memrefIndices,
365
- numColsOp, /* passthru=*/ pad1DOp);
366
- rewriter.create <scf::YieldOp>(loc, loadSlice->getResult (0 ));
367
- },
368
- [&](OpBuilder &b, Location loc) {
369
- // Inactive rows are filled with pad.
370
- rewriter.create <scf::YieldOp>(loc, pad1DOp.getResult ());
371
- });
372
-
373
- // TODO: If the load is vertical the transpose can't be done in-flight with
374
- // a regular (SVE) maskedload. Propagate layout to
375
- // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
376
- // is currently broken.
365
+ auto loadSlice = rewriter.create <vector::MaskedLoadOp>(
366
+ loc, tileSliceType, tileLoadOp.getBase (), memrefIndices, maskOp1D,
367
+ /* passthru=*/ pad1DOp);
377
368
378
369
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
379
370
rewriter.create <arm_sme::MoveVectorToTileSliceOp>(
380
- loc, tileType, slice ->getResult (0 ), tile, tileSliceIndex,
371
+ loc, tileType, loadSlice ->getResult (0 ), tile, tileSliceIndex,
381
372
tileLoadOp.getLayout ());
382
373
383
374
rewriter.setInsertionPointAfter (forOp);
0 commit comments