-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] More precisely model dataflow in ArmSME to SCF lowerings #73922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue)
Changes: Since #73253, loops over tiles in SSA form (i.e. loops that take `iter_args` and yield a new tile) are supported, so this patch updates ArmSME lowerings to this form.
Example:
IR before:
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1
{
arm_sme.move_vector_to_tile_slice
%broadcast_to_1d, %tile, %tile_slice_index :
vector<[4]xi32> into vector<[4]x[4]xi32>
}
// ... later use %tile
IR now:
%broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
{
%tile_update = arm_sme.move_vector_to_tile_slice
%broadcast_to_1d, %iter_tile, %tile_slice_index :
vector<[4]xi32> into vector<[4]x[4]xi32>
scf.yield %tile_update : vector<[4]x[4]xi32>
}
// ... later use %broadcast_to_tile
Patch is 28.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73922.diff
5 Files Affected:
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 69c68663070b6d5..72b476c9f049537 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -61,16 +61,18 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
/// AFTER:
/// ```mlir
/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-/// %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+/// %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
/// %vscale = vector.vscale
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
+/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-/// %ptrue_s, %tile, %tile_slice_idx
+/// %ptrue_s, %iter_tile, %tile_slice_idx
/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
@@ -88,7 +90,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
auto tileElementType = tileType.getElementType();
// Allocate a new SME tile.
- auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+ auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
@@ -103,8 +105,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp =
- rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -121,14 +123,17 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
- tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ auto currentTile = forOp.getRegionIterArg(0);
+ auto loadSlice =
+ tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+ rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
+ currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
rewriter.setInsertionPointAfter(forOp);
- // Replace 'arm_sme.tile_load' with the tile.
- rewriter.replaceOp(tileLoadOp, tile);
+ // Replace 'arm_sme.tile_load' with the result.
+ rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
return success();
}
@@ -150,13 +155,15 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
-/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
+/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
/// %num_rows = arith.constant 2 : index
/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
-/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+/// %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1
+/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// %tile_update = arm_sme.load_tile_slice
-/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
///
@@ -202,14 +209,15 @@ struct TileLoadOpWithMaskAndPadZeroConversion
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive rows
// however, no load will occur so these need to be zeroed.
- auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+ auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
rewriter, loc, tileType);
// Create a loop to load the active tile slices from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto upperBound = numRows;
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+ ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -217,17 +225,20 @@ struct TileLoadOpWithMaskAndPadZeroConversion
// tile.
SmallVector<Value> memrefIndices;
auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
upperBound, memrefIndices, loc, rewriter);
- tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ auto loadSlice =
+ tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+ rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
+ currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
rewriter.setInsertionPointAfter(forOp);
- // Replace 'arm_sme.tile_load' with the tile.
- rewriter.replaceOp(tileLoadOp, tile);
+ // Replace 'arm_sme.tile_load' with the result.
+ rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
return success();
}
@@ -249,15 +260,18 @@ struct TileLoadOpWithMaskAndPadZeroConversion
/// ```mlir
/// ...
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
-/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
+/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
/// : memref<?x?xi32>, vector<[4]xi1>,
/// vector<[4]xi32> into vector<[4]xi32>
/// // Insert slice into tile
-/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
-/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// %tile_update = arm_sme.move_vector_to_tile_slice
+/// %slice, %iter_tile, %tile_slice_idx :
+/// vector<[4]xi32> into vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
struct TileLoadOpWithMaskAndPadNonZeroConversion
@@ -298,7 +312,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
loc, rewriter.getI32Type(), numCols);
// Allocate a new SME tile.
- auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+ auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
@@ -310,12 +324,13 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp =
- rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
// Combine masks.
auto rowIsActive = rewriter.create<arith::CmpIOp>(
@@ -344,14 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
/*passthru=*/pad1DOp);
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
- tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
- rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
- tileLoadOp.getLayout());
+ auto moveSlice =
+ tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
+ rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
+ tileSliceIndex, tileLoadOp.getLayout());
+ rewriter.create<scf::YieldOp>(loc, ValueRange{moveSlice});
rewriter.setInsertionPointAfter(forOp);
- // Replace 'arm_sme.tile_load' with the tile.
- rewriter.replaceOp(tileLoadOp, tile);
+ // Replace 'arm_sme.tile_load' with the result.
+ rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
return success();
}
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 3016c7b0a84772d..250c9914b8c2823 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -26,21 +26,26 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
}
/// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index. Sets the IR Builder insertion point as the loop body.
-/// Callers of this method are responsible for restoring it if needed.
-static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
- Type eltType) {
+/// the tile slice index and each iteration yields a new tile. Loop body is
+/// built via the callback, which returns the next tile value.
+template <typename LoopBodyCallback>
+static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
+ Location loc, Value initTile,
+ LoopBodyCallback callback) {
+ OpBuilder::InsertionGuard g(rewriter);
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(eltType));
+ loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp =
- rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+ ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
+ auto nextTile = callback(forOp);
+ rewriter.create<scf::YieldOp>(loc, ValueRange{nextTile});
return forOp;
}
@@ -242,27 +247,25 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
// Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
// ops that broadcast the constant to each tile slice.
- OpBuilder::InsertionGuard g(rewriter);
auto loc = constantOp.getLoc();
// Unpack 1-d vector type from 2-d vector type.
- auto tileSliceType =
- VectorType::get(tileType.getShape().drop_front(), tileElementType,
- /*scalableDims=*/{true});
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto denseAttr1D = DenseElementsAttr::get(
tileSliceType, denseAttr.getSplatValue<Attribute>());
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
- auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-
- auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
- auto tileSliceIndex = forOp.getInductionVar();
-
- // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
- rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, constantOp1D, tile, tileSliceIndex);
-
- rewriter.replaceOp(constantOp, tile);
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto forOp =
+ createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+ auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
+ // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+ // slice.
+ return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ });
+ rewriter.replaceOp(constantOp, forOp.getResult(0));
return success();
}
@@ -277,9 +280,13 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
/// is converted to:
///
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
-/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
-/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
+/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
+/// {
+/// %tile_update = arm_sme.move_vector_to_tile_slice
+/// %broadcast_to_1d, %iter_tile, %tile_slice_index :
+/// vector<[4]xi32> into vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
///
/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
@@ -293,20 +300,16 @@ struct BroadcastOpToArmSMELowering
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
- OpBuilder::InsertionGuard g(rewriter);
auto loc = broadcastOp.getLoc();
auto srcType = broadcastOp.getSourceType();
auto srcVectorType = dyn_cast<VectorType>(srcType);
- auto tileElementType = tileType.getElementType();
Value broadcastOp1D;
if (srcType.isIntOrFloat() ||
(srcVectorType && (srcVectorType.getRank() == 0))) {
// Broadcast scalar or 0-d vector to 1-d vector.
- auto tileSliceType =
- VectorType::get(tileType.getShape().drop_front(), tileElementType,
- /*scalableDims=*/{true});
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
broadcastOp1D = rewriter.create<vector::BroadcastOp>(
loc, tileSliceType, broadcastOp.getSource());
} else if (srcVectorType && (srcVectorType.getRank() == 1))
@@ -315,18 +318,20 @@ struct BroadcastOpToArmSMELowering
else
return failure();
- auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
// Create a loop over ZA tile slices.
- auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
- auto tileSliceIndex = forOp.getInductionVar();
-
- // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each
- // tile slice.
- rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, tile, tileSliceIndex);
-
- rewriter.replaceOp(broadcastOp, tile);
+ auto forOp =
+ createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+ auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
+ // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+ // to each tile slice.
+ return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ });
+
+ rewriter.replaceOp(broadcastOp, forOp.getResult(0));
return success();
}
@@ -341,9 +346,13 @@ struct BroadcastOpToArmSMELowering
/// is converted to:
///
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
-/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
-/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
+/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
+/// {
+/// %tile_update = arm_sme.move_vector_to_tile_slice
+/// %broadcast_to_1d, %iter_tile, %tile_slice_index :
+/// vector<[4]xi32> into vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
///
/// This is identical to vector.broadcast of a scalar.
@@ -356,11 +365,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
- OpBuilder::InsertionGuard g(rewriter);
auto loc = splatOp.getLoc();
-
auto srcType = splatOp.getOperand().getType();
- auto tileElementType = tileType.getElementType();
assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
// Avoid unused-variable warning when building without assertions.
@@ -371,17 +377,19 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
loc, tileSliceType, splatOp.getInput());
- auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
- auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
- auto tileSliceIndex = forOp.getInductionVar();
-
- rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+ auto forOp =
+ createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+ auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
+ return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ });
- rewriter.replaceOp(splatOp, tile);
+ rewriter.replaceOp(splatOp, forOp.getResult(0));
return success();
}
@@ -424,7 +432,6 @@ struct TransposeOpToArmSMELowering
if (permutation[0] != 1 || permutation[1] != 0)
return failure();
- OpBuilder::InsertionGuard g(rewriter);
auto loc = transposeOp.getLoc();
// Allocate buffer to store input tile to.
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index fc28645a7acf7c0..9fe80192809f310 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -6,16 +6,17 @@
// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-DAG: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK-DAG: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vecto...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the patch, Ben, this is a nice improvement.
auto forOp = | ||
createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) { | ||
auto tileSliceIndex = forOp.getInductionVar(); | ||
auto currentTile = forOp.getRegionIterArg(0); | ||
// Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile | ||
// slice. | ||
return rewriter.create<arm_sme::MoveVectorToTileSliceOp>( | ||
loc, tileType, constantOp1D, currentTile, tileSliceIndex); | ||
}); | ||
rewriter.replaceOp(constantOp, forOp.getResult(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is nice! Did clang-format format this? Indentation looks more than I'd expect
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, this is clang-format
😔
c65cf08
to
c07d3e4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, cheers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very nice, thanks! LGTM (feel free to ignore my nits)
|
||
rewriter.setInsertionPointToStart(forOp.getBody()); | ||
|
||
// Create 'arm_sme.load_tile_slice' to load tile slice from memory into | ||
// tile. | ||
SmallVector<Value> memrefIndices; | ||
auto tileSliceIndex = forOp.getInductionVar(); | ||
auto currentTile = forOp.getRegionIterArg(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] `updatedTile`? `outputTile`? "current" is a relative term (current with respect to what?), so not a fan :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming is tricky, but it's the tile at the start of the current iteration of the loop, which is taken from the `iter_args`. The `updatedTile` would be the tile you get from the operation in the loop (i.e. the result of `move_vector_to_tile_slice`), and `outputTile` the result of the `scf.for` (at least that's how I see it :))
Since llvm#73253 we now support loops in SSA form for tiles (i.e. loops that take `iter_args` and yield a new tile), so this patch updates lowerings to use that. This is an NFC, as it still lowers to the same intrinsics, but this makes IR less 'surprising' at a higher-level, and may be recognised by more transforms.
c07d3e4
to
c8411d3
Compare
Since #73253, loops over tiles in SSA form (i.e. loops that take `iter_args` and yield a new tile) are supported, so this patch updates ArmSME lowerings to this form. This is an NFC, as it still lowers to the same intrinsics, but this makes IR less 'surprising' at a higher-level, and may be recognised by more transforms.
Example:
IR before:
IR now: