Skip to content

[mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA #98620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 25, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jul 12, 2024

This adds a workaround rewrite that allows stores of unsupported SVE transposes such as:

%tr = vector.transpose %vec, [1, 0]
  : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
  : vector<[4]x2xf32>,  memref<?x?xf32>

To use SME tiles, which are possible to lower (when SME is available):

// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
   : vector<[4]x[4]xf32>, memref<?x?xf32>

@MacDue MacDue marked this pull request as ready for review July 18, 2024 19:14
@MacDue MacDue requested review from c-rhodes and removed request for banach-space, nicolasvasilache and dcaballe July 18, 2024 19:14
@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Changes

This adds a workaround rewrite that allows stores of unsupported SVE transposes such as:

%tr = vector.transpose %vec, [1, 0]
  : vector&lt;2x[4]xf32&gt; to vector&lt;[4]x2xf32&gt;
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
  : vector&lt;[4]x2xf32&gt;,  memref&lt;?x?xf32&gt;

To use SME tiles, which are possible to lower (when SME is available):

// Insert vector&lt;2x[4]xf32&gt; into an SME tile:
%0 = arm_sme.get_tile : vector&lt;[4]x[4]xf32&gt;
%1 = vector.extract %vec[0] : vector&lt;[4]xf32&gt; from vector&lt;2x[4]xf32&gt;
%2 = vector.insert %1, %0 [0] : vector&lt;[4]xf32&gt; into vector&lt;[4]x[4]xf32&gt;
%3 = vector.extract %vec[1] : vector&lt;[4]xf32&gt; from vector&lt;2x[4]xf32&gt;
%4 = vector.insert %3, %2 [1] : vector&lt;[4]xf32&gt; into vector&lt;[4]x[4]xf32&gt;
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector&lt;[4]x[4]xi1&gt;
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
   {permutation_map = affine_map&lt;(d0, d1) -&gt; (d1, d0)&gt;}
   : vector&lt;[4]x[4]xf32&gt;, memref&lt;?x?xf32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/98620.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+2-1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+159-11)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+71)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index dfd64f995546a..921234daad1f1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -202,7 +202,8 @@ def VectorLegalization
     "func::FuncDialect",
     "arm_sme::ArmSMEDialect",
     "vector::VectorDialect",
-    "arith::ArithDialect"
+    "arith::ArithDialect",
+    "index::IndexDialect"
   ];
 }
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 600f2ecdb51bc..8f9b5080e82db 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect
+  MLIRIndexDialect
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRFuncTransforms
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 96dad6518fec8..028e2327e2a4f 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -18,6 +18,8 @@
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -140,11 +142,11 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
 auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                          VectorType smeTileType,
                          bool transposeIndices = false) {
-  assert(isMultipleOfSMETileVectorType(type) &&
-         "`type` not multiple of SME tiles");
   return llvm::map_range(
-      StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
-                                              smeTileType.getDimSize(1)}),
+      StaticTileOffsetRange(
+          type.getShape(),
+          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
+           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
       [=](auto indices) {
         int row = int(indices[0]);
         int col = int(indices[1]);
@@ -374,6 +376,14 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
+  Value vscale = rewriter.create<vector::VectorScaleOp>(loc);
+  return [loc, vscale, &rewriter](int64_t multiplier) {
+    return rewriter.create<arith::MulIOp>(
+        loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+  };
+}
+
 /// Legalize a multi-tile transfer_write as a single store loop. This is done as
 /// part of type decomposition as at this level we know each tile write is
 /// disjoint, but that information is lost after decomposition (without analysis
@@ -440,12 +450,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
                                          kMatchFailureUnsupportedMaskOp);
 
     auto loc = writeOp.getLoc();
-    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
-    auto createVscaleMultiple = [&](int64_t multiplier) {
-      return rewriter.create<arith::MulIOp>(
-          loc, vscale,
-          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
-    };
+    auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
 
     // Get SME tile and slice types.
     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -775,6 +780,148 @@ struct ConvertIllegalShapeCastOpsToTransposes
   }
 };
 
+/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
+/// the ZA state. This workaround rewrite to support these transposes when ZA is
+/// available.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %transpose = vector.transpose %vec, [1, 0]
+///     : vector<2x[4]xf32> to vector<[4]x2xf32>
+///  vector.transfer_write %transpose, %dest[%y, %x]
+///     : vector<[4]x2xf32>,  memref<?x?xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///   %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+///   %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
+///   %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+///   %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
+///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+///   %c4_vscale = arith.muli %vscale, %c4 : index
+///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
+///   vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
+///      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+///      : vector<[4]x[4]xf32>, memref<?x?xf32>
+///  ```
+///
+/// Values larger than a single tile are supported via decomposition.
+struct LowerIllegalTransposeStoreViaZA
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!isSupportedMaskOp(writeOp.getMask()))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isIdentity())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp)
+      return failure();
+
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+
+    if (resultType.getRank() != 2)
+      return rewriter.notifyMatchFailure(transposeOp, "not rank 2");
+
+    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "not illegal/unsupported SVE transpose");
+
+    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
+    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
+
+    if (sourceType.getDimSize(0) <= 1 ||
+        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
+      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
+
+    auto loc = writeOp.getLoc();
+    auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
+
+    auto transposeMap = AffineMapAttr::get(
+        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
+
+    // Note: We need to use `get_tile` as there's no vector-level `undef`.
+    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
+    Value destTensorOrMemref = writeOp.getSource();
+    auto numSlicesPerTile =
+        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
+    auto numSlices =
+        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
+      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
+      // of slices from the source type into the SME tile. Without checking
+      // vscale (and emitting multiple implementations) we can't make use of the
+      // rows of the tile after 1*vscale rows.
+      Value tile = undefTile;
+      for (int d = 0, e = numSlicesPerTile; d < e; ++d) {
+        Value vector = rewriter.create<vector::ExtractOp>(
+            loc, transposeOp.getVector(),
+            rewriter.getIndexAttr(d + smeTile.row));
+        if (vector.getType() != smeSliceType) {
+          vector = rewriter.create<vector::ScalableExtractOp>(
+              loc, smeSliceType, vector, smeTile.col);
+        }
+        tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
+      }
+
+      // 2. Transpose the tile position.
+      auto transposedRow = createVscaleMultiple(smeTile.col);
+      auto transposedCol =
+          rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
+
+      // 3. Compute mask for tile store.
+      Value maskRows;
+      Value maskCols;
+      if (auto mask = writeOp.getMask()) {
+        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+        maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
+                                                  transposedRow);
+        maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
+                                                  transposedCol);
+        maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
+      } else {
+        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
+        maskCols = numSlices;
+      }
+      auto subMask = rewriter.create<vector::CreateMaskOp>(
+          loc, smeTileType.clone(rewriter.getI1Type()),
+          ValueRange{maskRows, maskCols});
+
+      // 4. Emit a transposed tile write.
+      auto writeIndices = writeOp.getIndices();
+      Value destRow =
+          rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
+      Value destCol =
+          rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
+      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
+          transposeMap, subMask, writeOp.getInBounds().value_or(ArrayAttr{}));
+
+      if (writeOp.hasPureTensorSemantics())
+        destTensorOrMemref = smeWrite.getResult();
+    }
+
+    if (writeOp.hasPureTensorSemantics())
+      rewriter.replaceOp(writeOp, destTensorOrMemref);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
@@ -796,7 +943,8 @@ struct VectorLegalizationPass
 
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
-                 ConvertIllegalShapeCastOpsToTransposes>(context);
+                 ConvertIllegalShapeCastOpsToTransposes,
+                 LowerIllegalTransposeStoreViaZA>(context);
     // Note: These two patterns are added with a high benefit to ensure:
     //  - Masked outer products are handled before unmasked ones
     //  - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 71d80bc16ea12..951b29b6e3805 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -544,3 +544,74 @@ func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
   %0 = arith.constant dense<42> : vector<[8]x[8]xi32>
   return %0 : vector<[8]x[8]xi32>
 }
+
+// -----
+
+// CHECK: #[[$TRANSPOSE_MAP_0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_store_scalable_via_za(
+// CHECK-SAME:                                   %[[VEC:.*]]: vector<2x[4]xf32>
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                   %[[I:.*]]: index,
+// CHECK-SAME:                                   %[[J:.*]]: index)
+func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+  // CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+  // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1>
+  // CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x2xf32>,  memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK: @transpose_store_scalable_via_za_masked(
+// CHECK-SAME:                                    %[[A:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[B:[a-z0-9]+]]: index)
+func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]]
+  // CHECK: %[[MASK:.*]] = vector.create_mask %[[A]], %[[MIN]] : vector<[4]x[4]xi1>
+  // CHECK: vector.transfer_write {{.*}} %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %a, %b : vector<[4]x2xi1>
+  %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+  vector.transfer_write %tr, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x2xf32>,  memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK: @transpose_store_scalable_via_za_multi_tile(
+// CHECK-SAME:                                       %[[VEC:.*]]: vector<8x[4]xf32>
+// CHECK-SAME:                                       %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                       %[[I:.*]]: index,
+// CHECK-SAME:                                       %[[J:.*]]: index)
+func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[VSCALE:.*]] = vector.vscale
+
+  // <skip 3x other extract+insert chain>
+  // CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
+  // CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK: %[[MASK:.*]] = vector.create_mask %c4_vscale, %c4 : vector<[4]x[4]xi1>
+  // CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+
+  // <skip 3x other extract+insert chain>
+  // CHECK: %[[V7:.*]] = vector.extract %arg0[7] : vector<[4]xf32> from vector<8x[4]xf32>
+  // CHECK: %[[TILE_1:.*]] = vector.insert %[[V7]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[C4]] : index
+  // CHECK: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[I]], %[[J_OFFSET]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  %tr = vector.transpose %vec, [1, 0] : vector<8x[4]xf32> to vector<[4]x8xf32>
+  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>,  memref<?x?xf32>
+  return
+}

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have one question, but you can answer after landing. Thanks, LGTM!

/// : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```
///
/// Values larger than a single tile are supported via decomposition.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🥳

rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
for (auto [index, smeTile] : llvm::enumerate(
decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
// 1. _Deliberately_ drop a scalable dimension and insert a fixed number
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are these dims dropped?

Copy link
Member Author

@MacDue MacDue Jul 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't dynamically index the array-of-vectors input (or dynamically select an SME tile). These restrictions mean this lowering just targets the lowest common denominator (that is vscale = 1).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, IIUC the scalability is dropped on L853. And, basically, this pattern will store at most 4 rows per tile?

MacDue added 5 commits July 25, 2024 15:26
…ME/ZA

This adds a workaround rewrite that allows stores of unsupported SVE
transposes such as:

```mlir
%tr = vector.transpose %vec, [1, 0]
  : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
  : vector<[4]x2xf32>,  memref<?x?xf32>
```

To use SME tiles, which are possible to lower (when SME is available):

```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
   : vector<[4]x[4]xf32>, memref<?x?xf32>
```
@MacDue MacDue force-pushed the transpose_workaround branch from cc09972 to a37d6da Compare July 25, 2024 15:37
@MacDue MacDue merged commit c194bc7 into llvm:main Jul 25, 2024
7 checks passed
@MacDue MacDue deleted the transpose_workaround branch July 25, 2024 17:15
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…ME/ZA (#98620)

Summary:
This adds a workaround rewrite that allows stores of unsupported SVE
transposes such as:

```mlir
%tr = vector.transpose %vec, [1, 0]
  : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
  : vector<[4]x2xf32>,  memref<?x?xf32>
```

To use SME tiles, which are possible to lower (when SME is available):

```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
   : vector<[4]x[4]xf32>, memref<?x?xf32>
```

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250566
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants