Skip to content

Commit ec829d7

Browse files
committed
Combine masks and replace if
1 parent a376455 commit ec829d7

File tree

2 files changed

+33
-42
lines changed

2 files changed

+33
-42
lines changed

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ struct TileLoadOpWithMaskAndPadZeroConversion
255255
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
256256
/// %num_rows = arith.constant 2 : index
257257
/// %num_cols = arith.constant 4 : index
258+
/// %num_cols_i32 = arith.index_castui %num_cols : index to i32
258259
/// %tile_id = arm_sme.get_tile_id : i32
259260
/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
260261
/// %vscale = vector.vscale
@@ -264,14 +265,13 @@ struct TileLoadOpWithMaskAndPadZeroConversion
264265
/// %svl_s = arith.muli %min_svl_s, %vscale : index
265266
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
266267
/// %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
267-
/// %slice = scf.if %row_is_active -> vector<[4]xi32> {
268-
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
269-
/// : memref<?x?xi32>, vector<[4]xi1>,
270-
/// vector<[4]xi32> into vector<[4]xi32>
271-
/// scf.yield %slice : vector<[4]xi32>
272-
/// } else {
273-
/// scf.yield %pad_1d : vector<[4]xi32>
274-
/// }
268+
/// %row_is_active_i32 = arith.extsi %row_is_active : i1 to i32
269+
/// %mask = arith.andi %row_is_active_i32, %num_cols_i32 : i32
270+
/// %mask_index = arith.index_cast %mask : i32 to index
271+
/// %mask_1d = vector.create_mask %mask_index : vector<[4]xi1>
272+
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad
273+
/// : memref<?x?xi32>, vector<[4]xi1>,
274+
/// vector<[4]xi32> into vector<[4]xi32>
275275
/// // Insert slice into tile
276276
/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
277277
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
@@ -312,11 +312,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
312312
auto numRows = createMaskOp.getOperands()[0];
313313
auto numCols = createMaskOp.getOperands()[1];
314314

315-
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
316-
auto predicateType =
317-
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
318-
auto numColsOp =
319-
rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
315+
auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
316+
loc, rewriter.getI32Type(), numCols);
320317

321318
// Create 'arm_sme.get_tile' op.
322319
auto tileId = rewriter.create<arm_sme::GetTileID>(
@@ -343,41 +340,35 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
343340

344341
auto tileSliceIndex = forOp.getInductionVar();
345342

343+
// Combine masks.
346344
auto rowIsActive = rewriter.create<arith::CmpIOp>(
347345
loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
346+
auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
347+
loc, rewriter.getI32Type(), rowIsActive);
348+
auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
349+
auto maskIndex =
350+
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
351+
auto predicateType =
352+
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
353+
auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
354+
loc, predicateType, maskIndex.getResult());
348355

349356
SmallVector<Value> memrefIndices;
350357
getMemrefIndices(tileLoadOp.getIndices(),
351358
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
352359
numTileSlices, memrefIndices, loc, rewriter);
353360

354361
// Splat pad into 1-D vector matching type of tile slice.
362+
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
355363
auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
356364

357-
Operation *slice = rewriter.create<scf::IfOp>(
358-
loc, rowIsActive,
359-
[&](OpBuilder &b, Location loc) {
360-
// If the row is active, emit a masked load where the predicate is
361-
// 'numCols'. Pad is used for inactive elements, taken from
362-
// passthru.
363-
auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
364-
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
365-
numColsOp, /*passthru=*/pad1DOp);
366-
rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
367-
},
368-
[&](OpBuilder &b, Location loc) {
369-
// Inactive rows are filled with pad.
370-
rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
371-
});
372-
373-
// TODO: If the load is vertical the transpose can't be done in-flight with
374-
// a regular (SVE) maskedload. Propagate layout to
375-
// 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
376-
// is currently broken.
365+
auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
366+
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
367+
/*passthru=*/pad1DOp);
377368

378369
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
379370
rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
380-
loc, tileType, slice->getResult(0), tile, tileSliceIndex,
371+
loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
381372
tileLoadOp.getLayout());
382373

383374
rewriter.setInsertionPointAfter(forOp);

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,20 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
6666
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
6767
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
6868
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
69-
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
69+
// CHECK-DAG: %[[NUM_COLS:.*]] = arith.constant 2 : index
70+
// CHECK-DAG: %[[NUM_COLS_I32:.*]] = arith.index_castui %[[NUM_COLS]] : index to i32
7071
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
7172
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
7273
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
7374
// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
75+
// CHECK-NEXT: %[[ROW_IS_ACTIVE_SEXT_I32:.*]] = arith.extsi %[[ROW_IS_ACTIVE]] : i1 to i32
76+
// CHECK-NEXT: %[[MASK:.*]] = arith.andi %[[ROW_IS_ACTIVE_SEXT_I32]], %[[NUM_COLS_I32]] : i32
77+
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
78+
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
7479
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
7580
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
76-
// CHECK: %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
77-
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
78-
// CHECK: scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
79-
// CHECK: } else {
80-
// CHECK: scf.yield %[[PAD_1D]] : vector<[4]xi32>
81-
// CHECK: }
82-
// CHECK: arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
81+
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
82+
// CHECK: arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
8383
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
8484
%c0 = arith.constant 0 : index
8585
%c2 = arith.constant 2 : index

0 commit comments

Comments
 (0)