Skip to content

[mlir][ArmSME] More precisely model dataflow in ArmSME to SCF lowerings #73922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 6, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Nov 30, 2023

Since #73253, loops over tiles in SSA form (i.e. loops that take iter_args and yield a new tile) are supported, so this patch updates ArmSME lowerings to this form. This is a NFC, as it still lowers to the same intrinsics, but this makes IR less 'surprising' at a higher-level, and may be recognised by more transforms.

Example:

IR before:

scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 
{
   arm_sme.move_vector_to_tile_slice 
     %broadcast_to_1d, %tile, %tile_slice_index : 
     vector<[4]xi32> into vector<[4]x[4]xi32>
}
// ... later use %tile

IR now:

%broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
    step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
{
   %tile_update = arm_sme.move_vector_to_tile_slice
      %broadcast_to_1d, %iter_tile, %tile_slice_index :
      vector<[4]xi32> into vector<[4]x[4]xi32>
  scf.yield %tile_update : vector<[4]x[4]xi32>
}
// ... later use %broadcast_to_tile

@llvmbot
Copy link
Member

llvmbot commented Nov 30, 2023

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

Since #73253 loops over tiles in SSA form (i.e. loops that take iter_args and yield a new tile) are supported, so this patch updates ArmSME lowerings to this form. This is a NFC, as it still lowers to the same intrinsics, but this makes IR less 'surprising' at a higher-level, and may be recognised by more transforms.

Example:

IR Before:

scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 
{
   arm_sme.move_vector_to_tile_slice 
     %broadcast_to_1d, %tile, %tile_slice_index : 
     vector&lt;[4]xi32&gt; into vector&lt;[4]x[4]xi32&gt;
}
// ... later use %tile

IR Now:

%broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
    step %c1 iter_args(%iter_tile = %init_tile) -&gt; (vector&lt;[4]x[4]xi32&gt;)
{
   %tile_update = arm_sme.move_vector_to_tile_slice
      %broadcast_to_1d, %iter_tile, %tile_slice_index :
      vector&lt;[4]xi32&gt; into vector&lt;[4]x[4]xi32&gt;
  scf.yield %tile_update : vector&lt;[4]x[4]xi32&gt;
}
// ... later use %broadcast_to_tile

Patch is 28.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73922.diff

5 Files Affected:

  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+49-32)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+60-53)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+11-8)
  • (modified) mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir (+8-6)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+4-3)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 69c68663070b6d5..72b476c9f049537 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -61,16 +61,18 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  AFTER:
 ///  ```mlir
 ///  %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-///  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+///  %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
 ///  %vscale = vector.vscale
 ///  %c0 = arith.constant 0 : index
 ///  %c1 = arith.constant 1 : index
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
-///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
+///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
 ///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-///      %ptrue_s, %tile, %tile_slice_idx
+///      %ptrue_s, %iter_tile, %tile_slice_idx
 ///        : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+///    scf.yield %tile_update : vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
@@ -88,7 +90,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     auto tileElementType = tileType.getElementType();
 
     // Allocate a new SME tile.
-    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
         rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
@@ -103,8 +105,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     // ..., SVL_Q).
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
+                                             step, ValueRange{initTile});
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
@@ -121,14 +123,17 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-        rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
-        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+    auto currentTile = forOp.getRegionIterArg(0);
+    auto loadSlice =
+        tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+            rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
+            currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+    rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
 
     rewriter.setInsertionPointAfter(forOp);
 
-    // Replace 'arm_sme.tile_load' with the tile.
-    rewriter.replaceOp(tileLoadOp, tile);
+    // Replace 'arm_sme.tile_load' with the result.
+    rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
 
     return success();
   }
@@ -150,13 +155,15 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  ```mlir
 ///  %c0 = arith.constant 0 : index
 ///  %c1 = arith.constant 1 : index
-///  %tile = arm_sme.zero : vector<[4]x[4]xi32>
+///  %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
 ///  %num_rows = arith.constant 2 : index
 ///  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
-///  scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+///  %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1
+///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
 ///    %tile_update = arm_sme.load_tile_slice
-///      %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+///      %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
 ///      memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
+///    scf.yield %tile_update : vector<[4]x[4]xi32>
 ///  }
 ///  ```
 ///
@@ -202,14 +209,15 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     // Initialize tile with zero to satisfy padding. Inactive cols will be
     // zeroed anyway since the loads use zeroing predication. For inactive rows
     // however, no load will occur so these need to be zeroed.
-    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
         rewriter, loc, tileType);
 
     // Create a loop to load the active tile slices from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     auto upperBound = numRows;
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+                                             ValueRange{initTile});
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
@@ -217,17 +225,20 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     // tile.
     SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
+    auto currentTile = forOp.getRegionIterArg(0);
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      upperBound, memrefIndices, loc, rewriter);
-    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-        rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
-        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+    auto loadSlice =
+        tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+            rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
+            currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+    rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
 
     rewriter.setInsertionPointAfter(forOp);
 
-    // Replace 'arm_sme.tile_load' with the tile.
-    rewriter.replaceOp(tileLoadOp, tile);
+    // Replace 'arm_sme.tile_load' with the result.
+    rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
 
     return success();
   }
@@ -249,15 +260,18 @@ struct TileLoadOpWithMaskAndPadZeroConversion
 ///  ```mlir
 ///  ...
 ///  %pad_1d = arith.constant dense<1> : vector<[4]xi32>
-///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
+///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
 ///    ...
 ///    %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
 ///    %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
 ///      : memref<?x?xi32>, vector<[4]xi1>,
 ///        vector<[4]xi32> into vector<[4]xi32>
 ///    // Insert slice into tile
-///    arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
-///      : vector<[4]xi32> into vector<[4]x[4]xi32>
+///    %tile_update = arm_sme.move_vector_to_tile_slice
+///      %slice, %iter_tile, %tile_slice_idx :
+///      vector<[4]xi32> into vector<[4]x[4]xi32>
+///    scf.yield %tile_update : vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileLoadOpWithMaskAndPadNonZeroConversion
@@ -298,7 +312,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         loc, rewriter.getI32Type(), numCols);
 
     // Allocate a new SME tile.
-    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
         rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
@@ -310,12 +324,13 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
+                                             step, ValueRange{initTile});
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
     auto tileSliceIndex = forOp.getInductionVar();
+    auto currentTile = forOp.getRegionIterArg(0);
 
     // Combine masks.
     auto rowIsActive = rewriter.create<arith::CmpIOp>(
@@ -344,14 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         /*passthru=*/pad1DOp);
 
     // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
-    tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
-        rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
-        tileLoadOp.getLayout());
+    auto moveSlice =
+        tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
+            rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
+            tileSliceIndex, tileLoadOp.getLayout());
+    rewriter.create<scf::YieldOp>(loc, ValueRange{moveSlice});
 
     rewriter.setInsertionPointAfter(forOp);
 
-    // Replace 'arm_sme.tile_load' with the tile.
-    rewriter.replaceOp(tileLoadOp, tile);
+    // Replace 'arm_sme.tile_load' with the result.
+    rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
 
     return success();
   }
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 3016c7b0a84772d..250c9914b8c2823 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -26,21 +26,26 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
 }
 
 /// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index. Sets the IR Builder insertion point as the loop body.
-/// Callers of this method are responsible for restoring it if needed.
-static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
-                                        Type eltType) {
+/// the tile slice index and each iteration yields a new tile. Loop body is
+/// built via the callback, which returns the next tile value.
+template <typename LoopBodyCallback>
+static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
+                                           Location loc, Value initTile,
+                                           LoopBodyCallback callback) {
+  OpBuilder::InsertionGuard g(rewriter);
   auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-      loc, arm_sme::getSMETileSliceMinNumElts(eltType));
+      loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
   auto vscale =
       rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
   auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   auto numTileSlices =
       rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-  auto forOp =
-      rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+                                           ValueRange{initTile});
   rewriter.setInsertionPointToStart(forOp.getBody());
+  auto nextTile = callback(forOp);
+  rewriter.create<scf::YieldOp>(loc, ValueRange{nextTile});
   return forOp;
 }
 
@@ -242,27 +247,25 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
 
     // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
     // ops that broadcast the constant to each tile slice.
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = constantOp.getLoc();
 
     // Unpack 1-d vector type from 2-d vector type.
-    auto tileSliceType =
-        VectorType::get(tileType.getShape().drop_front(), tileElementType,
-                        /*scalableDims=*/{true});
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
     auto denseAttr1D = DenseElementsAttr::get(
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
-    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-
-    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
-    auto tileSliceIndex = forOp.getInductionVar();
-
-    // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, constantOp1D, tile, tileSliceIndex);
-
-    rewriter.replaceOp(constantOp, tile);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    auto forOp =
+        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+          auto tileSliceIndex = forOp.getInductionVar();
+          auto currentTile = forOp.getRegionIterArg(0);
+          // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+          // slice.
+          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+              loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+        });
+    rewriter.replaceOp(constantOp, forOp.getResult(0));
 
     return success();
   }
@@ -277,9 +280,13 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
 /// is converted to:
 ///
 ///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-///   scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
-///     arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
-///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
+///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
+///   {
+///     %tile_update = arm_sme.move_vector_to_tile_slice
+///        %broadcast_to_1d, %iter_tile, %tile_slice_index :
+///        vector<[4]xi32> into vector<[4]x[4]xi32>
+///     scf.yield %tile_update : vector<[4]x[4]xi32>
 ///   }
 ///
 /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
@@ -293,20 +300,16 @@ struct BroadcastOpToArmSMELowering
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = broadcastOp.getLoc();
 
     auto srcType = broadcastOp.getSourceType();
     auto srcVectorType = dyn_cast<VectorType>(srcType);
-    auto tileElementType = tileType.getElementType();
 
     Value broadcastOp1D;
     if (srcType.isIntOrFloat() ||
         (srcVectorType && (srcVectorType.getRank() == 0))) {
       // Broadcast scalar or 0-d vector to 1-d vector.
-      auto tileSliceType =
-          VectorType::get(tileType.getShape().drop_front(), tileElementType,
-                          /*scalableDims=*/{true});
+      VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
       broadcastOp1D = rewriter.create<vector::BroadcastOp>(
           loc, tileSliceType, broadcastOp.getSource());
     } else if (srcVectorType && (srcVectorType.getRank() == 1))
@@ -315,18 +318,20 @@ struct BroadcastOpToArmSMELowering
     else
       return failure();
 
-    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Create a loop over ZA tile slices.
-    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
-    auto tileSliceIndex = forOp.getInductionVar();
-
-    // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each
-    // tile slice.
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, broadcastOp1D, tile, tileSliceIndex);
-
-    rewriter.replaceOp(broadcastOp, tile);
+    auto forOp =
+        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+          auto tileSliceIndex = forOp.getInductionVar();
+          auto currentTile = forOp.getRegionIterArg(0);
+          // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+          // to each tile slice.
+          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+              loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+        });
+
+    rewriter.replaceOp(broadcastOp, forOp.getResult(0));
 
     return success();
   }
@@ -341,9 +346,13 @@ struct BroadcastOpToArmSMELowering
 /// is converted to:
 ///
 ///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-///   scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
-///     arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
-///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
+///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
+///   {
+///     %tile_update = arm_sme.move_vector_to_tile_slice
+///        %broadcast_to_1d, %iter_tile, %tile_slice_index :
+///        vector<[4]xi32> into vector<[4]x[4]xi32>
+///     scf.yield %tile_update : vector<[4]x[4]xi32>
 ///   }
 ///
 /// This is identical to vector.broadcast of a scalar.
@@ -356,11 +365,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = splatOp.getLoc();
-
     auto srcType = splatOp.getOperand().getType();
-    auto tileElementType = tileType.getElementType();
 
     assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
     // Avoid unused-variable warning when building without assertions.
@@ -371,17 +377,19 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
         loc, tileSliceType, splatOp.getInput());
 
-    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.
-    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
-    auto tileSliceIndex = forOp.getInductionVar();
-
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+    auto forOp =
+        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+          auto tileSliceIndex = forOp.getInductionVar();
+          auto currentTile = forOp.getRegionIterArg(0);
+          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+              loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+        });
 
-    rewriter.replaceOp(splatOp, tile);
+    rewriter.replaceOp(splatOp, forOp.getResult(0));
 
     return success();
   }
@@ -424,7 +432,6 @@ struct TransposeOpToArmSMELowering
     if (permutation[0] != 1 || permutation[1] != 0)
       return failure();
 
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = transposeOp.getLoc();
 
     // Allocate buffer to store input tile to.
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index fc28645a7acf7c0..9fe80192809f310 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -6,16 +6,17 @@
 
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor(
 // CHECK-SAME:                                   %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[VSCALE:.*]] = vecto...
[truncated]

@MacDue MacDue requested a review from dcaballe November 30, 2023 10:54
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for patch Ben this is a nice improvement

Comment on lines +259 to +270
auto forOp =
createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
auto tileSliceIndex = forOp.getInductionVar();
auto currentTile = forOp.getRegionIterArg(0);
// Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
// slice.
return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
loc, tileType, constantOp1D, currentTile, tileSliceIndex);
});
rewriter.replaceOp(constantOp, forOp.getResult(0));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is nice! Did clang-format format this? Indentation looks more than I'd expect

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this is clang-format 😔

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, cheers

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very nice, thanks! LGTM (feel free to ignore my nits)


rewriter.setInsertionPointToStart(forOp.getBody());

// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
// tile.
SmallVector<Value> memrefIndices;
auto tileSliceIndex = forOp.getInductionVar();
auto currentTile = forOp.getRegionIterArg(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] updatedTile? outputTile? "current" is a relative term (current with respect to what?), so not a fan :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming is tricky, but it's the tile at the start of the current iteration of the loop, which is taken from the iter_args. The updatedTile would be the tile you get from the operation in the loop (i.e. the result of move_vector_to_tile_slice), and outputTile the result of the scf.for (at least that's how I see it :))

Since llvm#73253 we now loops in SSA form for tiles (i.e. loops that take
`iter_args` and yield a new tile), so this patch updates lowerings to
use that. This is a NFC, as it still lowers to the same intrinsics,
but this makes IR less 'surprising' at a higher-level, and may be
recognised by more transforms.
@MacDue MacDue force-pushed the arm_sme_scf_yield branch from c07d3e4 to c8411d3 Compare December 6, 2023 13:09
@MacDue MacDue merged commit b0b69fd into llvm:main Dec 6, 2023
@MacDue MacDue deleted the arm_sme_scf_yield branch December 6, 2023 14:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants