Skip to content

[mlir][ArmSME] More precisely model dataflow in ArmSME to SCF lowerings #73922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 49 additions & 32 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,18 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
/// AFTER:
/// ```mlir
/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
/// %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
/// %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
/// %vscale = vector.vscale
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
/// %ptrue_s, %tile, %tile_slice_idx
/// %ptrue_s, %iter_tile, %tile_slice_idx
/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
Expand All @@ -88,7 +90,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
auto tileElementType = tileType.getElementType();

// Allocate a new SME tile.
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
rewriter, loc, tileType);

// Create a loop that loads each ZA tile slice from memory.
Expand All @@ -103,8 +105,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
auto forOp =
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
step, ValueRange{initTile});

rewriter.setInsertionPointToStart(forOp.getBody());

Expand All @@ -121,14 +123,17 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
auto currentTile = forOp.getRegionIterArg(0);
auto loadSlice =
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());

rewriter.setInsertionPointAfter(forOp);

// Replace 'arm_sme.tile_load' with the tile.
rewriter.replaceOp(tileLoadOp, tile);
// Replace 'arm_sme.tile_load' with the result.
rewriter.replaceOp(tileLoadOp, forOp.getResult(0));

return success();
}
Expand All @@ -150,13 +155,15 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
/// %num_rows = arith.constant 2 : index
/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
/// %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// %tile_update = arm_sme.load_tile_slice
/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
///
Expand Down Expand Up @@ -202,32 +209,36 @@ struct TileLoadOpWithMaskAndPadZeroConversion
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive rows
// however, no load will occur so these need to be zeroed.
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
rewriter, loc, tileType);

// Create a loop to load the active tile slices from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto upperBound = numRows;
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
ValueRange{initTile});

rewriter.setInsertionPointToStart(forOp.getBody());

// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
// tile.
SmallVector<Value> memrefIndices;
auto tileSliceIndex = forOp.getInductionVar();
auto currentTile = forOp.getRegionIterArg(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] updatedTile? outputTile? "current" is a relative term (current with respect to what?), so not a fan :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming is tricky, but it's the tile at the start of the current iteration of the loop, which is taken from the iter_args. The updatedTile would be the tile you get from the operation in the loop (i.e. the result of move_vector_to_tile_slice), and outputTile the result of the scf.for (at least that's how I see it :))

getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
upperBound, memrefIndices, loc, rewriter);
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
auto loadSlice =
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());

rewriter.setInsertionPointAfter(forOp);

// Replace 'arm_sme.tile_load' with the tile.
rewriter.replaceOp(tileLoadOp, tile);
// Replace 'arm_sme.tile_load' with the result.
rewriter.replaceOp(tileLoadOp, forOp.getResult(0));

return success();
}
Expand All @@ -249,15 +260,18 @@ struct TileLoadOpWithMaskAndPadZeroConversion
/// ```mlir
/// ...
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
/// : memref<?x?xi32>, vector<[4]xi1>,
/// vector<[4]xi32> into vector<[4]xi32>
/// // Insert slice into tile
/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
/// %tile_update = arm_sme.move_vector_to_tile_slice
/// %slice, %iter_tile, %tile_slice_idx :
/// vector<[4]xi32> into vector<[4]x[4]xi32>
/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
struct TileLoadOpWithMaskAndPadNonZeroConversion
Expand Down Expand Up @@ -298,7 +312,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
loc, rewriter.getI32Type(), numCols);

// Allocate a new SME tile.
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
rewriter, loc, tileType);

// Create a loop that loads each ZA tile slice from memory.
Expand All @@ -310,12 +324,13 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
auto forOp =
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
step, ValueRange{initTile});

rewriter.setInsertionPointToStart(forOp.getBody());

auto tileSliceIndex = forOp.getInductionVar();
auto currentTile = forOp.getRegionIterArg(0);

// Combine masks.
auto rowIsActive = rewriter.create<arith::CmpIOp>(
Expand Down Expand Up @@ -344,14 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
/*passthru=*/pad1DOp);

// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
tileLoadOp.getLayout());
auto moveSlice =
tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
tileSliceIndex, tileLoadOp.getLayout());
rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());

rewriter.setInsertionPointAfter(forOp);

// Replace 'arm_sme.tile_load' with the tile.
rewriter.replaceOp(tileLoadOp, tile);
// Replace 'arm_sme.tile_load' with the result.
rewriter.replaceOp(tileLoadOp, forOp.getResult(0));

return success();
}
Expand Down
Loading