[mlir][ArmSME] Lower multi-tile stores to a single loop #96187

Merged (5 commits) on Jun 25, 2024

Conversation

@MacDue MacDue commented Jun 20, 2024

This adds a new pattern that can legalize a multi-tile transfer_write as a single store loop. This is done during type decomposition because, at this level, we know the tile writes are disjoint; that information is lost after decomposition (unless an analysis reconstructs it).

Example (pseudo-MLIR):

vector.transfer_write %vector, %dest[%y, %x], %mask
  : vector<[16]x[8]xi16>, memref<?x?xi16>

Is rewritten to:

scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
  %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
    : vector<[8]xi1> from vector<[16]x[8]xi1>           |
  %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
    : vector<[8]xi16> from vector<[8]x[8]xi16>          |
  vector.transfer_write %upper_slice,                   |
    %dest[%slice_idx + %y, %x], %upper_slice_mask       |
    : vector<[8]xi16>, memref<?x?xi16>                  ┘
  %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
  %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
    : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
  %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
    : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
  vector.transfer_write %lower_slice,                         |
    %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
    : vector<[8]xi16>, memref<?x?xi16>                        ┘
}
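
The disjointness claim can be checked with a small host-side model of the loop above. This is only a sketch under stated assumptions: `simulateStoreLoop` and `writesExactlyOnce` are hypothetical names that do not appear in the patch, `vscale` is treated as an ordinary integer, and the destination memref is modelled as a flat array of write counters.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of the rewritten loop for a vector<[16]x[8]xi16> write.
// On real hardware `vscale` is a runtime value; here it is a plain parameter.
// Returns, for each destination element, how many times it was written.
std::vector<int> simulateStoreLoop(int64_t vscale, int64_t y, int64_t x,
                                   int64_t destRows, int64_t destCols) {
  std::vector<int> writes(destRows * destCols, 0);
  const int64_t tileDim = 8 * vscale; // rows (and columns) of one i16 SME tile

  // Models the 1-D transfer_write of one tile slice to dest[row, x...].
  auto storeSlice = [&](int64_t row) {
    for (int64_t col = 0; col < tileDim; ++col)
      ++writes[row * destCols + (x + col)];
  };

  // A single loop over the slices of one tile; both tiles are stored per trip.
  for (int64_t sliceIdx = 0; sliceIdx < tileDim; ++sliceIdx) {
    storeSlice(sliceIdx + y); // upper tile
    const int64_t lowerSliceIdx = sliceIdx + tileDim;
    storeSlice(lowerSliceIdx + y); // lower tile
  }
  return writes;
}

// True if every element of the written window is stored exactly once and
// nothing outside the window is touched.
bool writesExactlyOnce(int64_t vscale) {
  const int64_t y = 1, x = 2; // arbitrary write offsets for the check
  const int64_t rows = 16 * vscale + y + 1;
  const int64_t cols = 8 * vscale + x + 1;
  const std::vector<int> writes = simulateStoreLoop(vscale, y, x, rows, cols);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      const bool inside =
          r >= y && r < y + 16 * vscale && c >= x && c < x + 8 * vscale;
      if (writes[r * cols + c] != (inside ? 1 : 0))
        return false;
    }
  }
  return true;
}
```

Calling `writesExactlyOnce` with a few values of `vscale` shows that the upper and lower tile stores never overlap, which is the property the pattern relies on when it merges them into one loop.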

llvmbot commented Jun 20, 2024

@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This adds a new pattern that can legalize a multi-tile transfer_write as a single store loop. This is done during type decomposition because, at this level, we know the tile writes are disjoint; that information is lost after decomposition (unless an analysis reconstructs it).

Example (in pseudo-MLIR):

vector.transfer_write vector, dest[x, y], mask
  : vector<[16]x[4]xf32>, memref<?x?xf32>

Is rewritten to:

for i in range (0, 4 * vscale) {
  let sliceRow = i + tile_n.row * vscale;              ─┐
  let sliceCol = tile_n.col * vscale;                   |
  slice = vector.extract tile_n[i]                      |
    : vector<[4]xf32> from vector<[4]x[4]xf32>          |
  slice_mask = vector.extract mask[sliceRow]            |- Repeated 4x for
    : vector<[4]xi1> from vector<[16]x[4]xi1>           |  all tiles in
  vector.transfer_write                                 |  [16]x[4]
    slice, dest[x + sliceRow, y + sliceCol], slice_mask |
    : vector<[4]xf32>, memref<?x?xf32>                  ┘
}

Full diff: https://github.com/llvm/llvm-project/pull/96187.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+130-3)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+63-6)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir (+2-1)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index b595c6dd8a684..426cbfea12374 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -373,6 +374,129 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+/// Legalize a multi-tile transfer_write as a single store loop. This is done as
+/// part of type decomposition as at this level we know each tile write is
+/// disjoint, but that information is lost after decomposition (without
+/// static analysis).
+///
+/// Example (in pseudo-MLIR):
+///
+/// ```
+/// vector.transfer_write vector, dest[x, y], mask
+///   : vector<[16]x[4]xf32>, memref<?x?xf32>
+/// ```
+/// Is rewritten to:
+/// ```
+/// for i in range (0, 4 * vscale) {
+///   let sliceRow = i + tile_n.row * vscale;              ─┐
+///   let sliceCol = tile_n.col * vscale;                   |
+///   slice = vector.extract tile_n[i]                      |
+///     : vector<[4]xf32> from vector<[4]x[4]xf32>          |
+///   slice_mask = vector.extract mask[sliceRow]            |- Repeated 4x for
+///     : vector<[4]xi1> from vector<[16]x[4]xi1>           |  all tiles in
+///   vector.transfer_write                                 |  [16]x[4]
+///     slice, dest[x + sliceRow, y + sliceCol], slice_mask |
+///     : vector<[4]xf32>, memref<?x?xf32>                  ┘
+/// }
+/// ```
+struct LegalizeMultiTileTransferWriteAsStoreLoop
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (writeOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "tensor semantics are unsupported");
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    bool transposed = !permutationMap.isIdentity();
+    if (transposed)
+      return rewriter.notifyMatchFailure(writeOp, "transpose unsupported");
+
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
+                                              vectorType.getDimSize(1) > 16)))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto loc = writeOp.getLoc();
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto createVscaleMultiple = [&](int64_t multiplier) {
+      return rewriter.create<arith::MulIOp>(
+          loc, vscale,
+          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+    };
+
+    // Get SME tile and slice types.
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto minTileSlices = smeTileType.getDimSize(0);
+    VectorType sliceMaskType =
+        VectorType::get(minTileSlices, rewriter.getI1Type(), true);
+
+    // Create loop over all tile slices.
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = createVscaleMultiple(minTileSlices);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto storeLoop =
+        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(storeLoop.getBody());
+
+    // For each sub-tile of the multi-tile `vectorType`.
+    auto inputSMETiles = adaptor.getVector();
+    auto inductionVar = storeLoop.getInductionVar();
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
+      // The coordinates of the tile within `vectorType`.
+      auto tileRow = createVscaleMultiple(smeTile.row);
+      auto tileCol = createVscaleMultiple(smeTile.col);
+
+      // The current slice of `vectorType` we are processing.
+      auto sliceIndex =
+          rewriter.create<arith::AddIOp>(loc, tileRow, inductionVar);
+
+      // Where in the destination memref the current slice will be stored.
+      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
+                                                     writeOp.getIndices()[0]);
+      auto storeCol =
+          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
+
+      // Extract the mask for the current slice.
+      Value sliceMask = nullptr;
+      if (mask) {
+        sliceMask = rewriter.create<vector::ExtractOp>(
+            loc, mask, OpFoldResult(sliceIndex));
+        if (sliceMaskType != sliceMask.getType())
+          sliceMask = rewriter.create<vector::ScalableExtractOp>(
+              loc, sliceMaskType, sliceMask, smeTile.col);
+      }
+
+      // Extract and store the current slice.
+      Value tile = inputSMETiles[index];
+      auto slice = rewriter.create<vector::ExtractOp>(loc, tile, inductionVar);
+      rewriter.create<vector::TransferWriteOp>(
+          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
+          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
+          sliceMask,
+          rewriter.getBoolArrayAttr(
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
+    }
+
+    rewriter.eraseOp(writeOp);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ArmSME-specific fixup canonicalizations/folds
 //===----------------------------------------------------------------------===//
@@ -663,9 +787,12 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes>(context);
-    // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
-        converter, context, 1024);
+    // Note: These two patterns are added with a high benefit to ensure:
+    //  - Masked outer products are handled before unmasked ones
+    //  - Multi-tile writes are lowered as a store loop (if possible)
+    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
+                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
+                                                            /*benefit=*/1024);
     patterns.add<LegalizeArithConstantOpsByDecomposition,
                  LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index f43ef1cce787c..70d2a06797a31 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
 func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
   // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
-  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
@@ -201,6 +207,47 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 
 // -----
 
+// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
+// CHECK-SAME:                                    %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:                                    %[[DIM_0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM_1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
+func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[MASK:.*]] =  vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
 #transpose = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +256,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +269,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
   // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
index ada744b322fe9..03a7d25cffa76 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule  \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
+// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
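
The per-tile offsets that the new pattern computes inside the loop body can be sketched outside of MLIR as follows. This is an illustrative model only: `SliceStore` and `planSliceStores` are hypothetical names, the row-major tile order is an assumption read off the CHECK lines in the tests above, and all offsets are expressed in multiples of `vscale`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One slice store emitted per tile in the loop body (offsets are multiples of
// vscale). Hypothetical type, used only for this sketch.
struct SliceStore {
  int64_t rowOffset;  // added to the induction variable and the row index
  int64_t colOffset;  // added to the column index of the original write
  int64_t maskColumn; // position used when extracting the per-tile sub-mask
};

// Enumerates the SME tiles of a rows x cols multi-tile vector, assuming a
// row-major order, and records the offsets each tile's slice store would use.
std::vector<SliceStore> planSliceStores(int64_t rows, int64_t cols,
                                        int64_t tileDim) {
  std::vector<SliceStore> plan;
  for (int64_t row = 0; row < rows; row += tileDim)
    for (int64_t col = 0; col < cols; col += tileDim)
      plan.push_back({row, col, col});
  return plan;
}
```

For the masked `[8]x[8]xf32` test this yields four stores at offsets (0, 0), (0, 4), (4, 0) and (4, 4), matching the `%[[C0]]`/`%[[C4_VSCALE]]` indices in the CHECK lines; for a `[16]x[4]xf32` write it yields row offsets 0, 4, 8 and 12 with a zero column offset.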


@banach-space banach-space left a comment

Thanks!

I've only skimmed through and left a few comments inline. I'll try to take another look before I disappear for a few days. In the meantime, could you add some negative tests and update the docs?

MacDue added 3 commits June 25, 2024 09:21

@c-rhodes c-rhodes left a comment

one final nit otherwise LGTM cheers

@MacDue MacDue merged commit 5ed5d72 into llvm:main Jun 25, 2024
5 of 6 checks passed
@MacDue MacDue deleted the za_store_loop branch June 25, 2024 11:46
MacDue added a commit to MacDue/llvm-project that referenced this pull request Jul 3, 2024
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024