Skip to content

[mlir][ArmSME] Add support for vector.transpose #66760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 25, 2023

Conversation

c-rhodes
Copy link
Collaborator

This patch adds support for lowering vector.transpose to ArmSME. It's
implemented by storing the input tile of the tranpose to memory and
reloading vertically, building on top of the tile slice layout support.

Tranposing via memory is obviously expensive, the current intention is
to avoid the transpose if possible, this is therefore intended as a
fallback and to provide base support for Vector ops. If it turns out
transposes can't be avoided then this should be replaced with a more
optimal implementation, perhaps with tile <-> vector (MOVA) ops.

Depends on #66758.

@c-rhodes c-rhodes changed the title Arm sme vector transpose [mlir][ArmSME] Add support for vector.transpose Sep 19, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2023

@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Changes

This patch adds support for lowering vector.transpose to ArmSME. It's
implemented by storing the input tile of the tranpose to memory and
reloading vertically, building on top of the tile slice layout support.

Tranposing via memory is obviously expensive, the current intention is
to avoid the transpose if possible, this is therefore intended as a
fallback and to provide base support for Vector ops. If it turns out
transposes can't be avoided then this should be replaced with a more
optimal implementation, perhaps with tile <-> vector (MOVA) ops.

Depends on #66758.


Patch is 110.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66760.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+5)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td (+79-40)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt (+7)
  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+9-9)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+83-7)
  • (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+13)
  • (modified) mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt (+2)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+71-22)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+31-11)
  • (modified) mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir (+3-3)
  • (added) mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir (+401)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+450-90)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+130-8)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+110)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir (+113)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index d1ed02abfd5c552..f947fc8fe1631b8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -21,6 +21,11 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
+
 #include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..884773aa559bcf9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
 #ifndef ARMSME_OPS
 #define ARMSME_OPS
 
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -35,7 +36,9 @@ def ArmSME_Dialect : Dialect {
     https://developer.arm.com/documentation/ddi0616
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
-  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect"];
+  let useDefaultAttributePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -83,6 +86,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
                   "::llvm::cast<VectorType>($_self).getElementType())"
                   ".getWidth())">;
 
+//===----------------------------------------------------------------------===//
+// ArmSME attr definitions
+//===----------------------------------------------------------------------===//
+
+def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
+  I32EnumAttrCase<"Horizontal", 0, "hor">,
+  I32EnumAttrCase<"Vertical", 1, "ver">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies the layout of a tile slice in a tile.
+def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
+                                          "layout"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -239,27 +260,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
-    with the shape defined by the 2D scalable vector type of the result tile.
-    The slice of memory must be contiguous. The memref must be either rank 1 or
-    rank 2 with dynamic dimensions, since the operation is scalable, and the
-    element type must be a scalar that matches the element type of the result.
+    with the shape defined by the 2D scalable vector type of the result tile. A
+    tile slice layout attribute specifies whether the slices of the tile being
+    loaded are horizontal or vertical. The slice of memory must be contiguous.
+    The memref must be either rank 1 or rank 2 with dynamic dimensions, since
+    the operation is scalable, and the element type must be a scalar that
+    matches the element type of the result.
+
+    The default tile slice layout when lowering from higher-level dialects is
+    horizontal.
 
-    Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+    Example 1: Load an 8-bit element ZA tile with horizontal layout from memory (ZA0.B).
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Load a FP 32-bit element ZA tile from memory.
+    Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Load a 128-bit element ZA tile from memory.
+    Example 3: Load a 128-bit element ZA tile with horizontal layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
       Variadic<Index>:$indices);
   let results = (outs SMETile:$result);
@@ -274,7 +301,8 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   }];
 
   let assemblyFormat =
-      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+      "$layout `,` $base `[` $indices `]` attr-dict "
+        "`:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
@@ -282,27 +310,32 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
   let description = [{
     Stores a 2D SME "virtual tile" to memory defined by a base and indices,
     with the shape defined by the 2D scalable vector type of the tile being
-    stored. The slice of memory must be contiguous. The memref must be either
-    rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
-    and the element type must be a scalar that matches the element type of the
-    result.
+    stored. A tile slice layout attribute specifies whether the slices of the
+    tile being stored are horizontal or vertical. The slice of memory must be
+    contiguous.  The memref must be either rank 1 or rank 2 with dynamic
+    dimensions, since the operation is scalable, and the element type must be a
+    scalar that matches the element type of the result.
+
+    The default tile slice layout when lowering from higher-level dialects is
+    horizontal.
 
-    Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+    Example 1: Store an 8-bit element ZA tile with horizontal layout to memory (ZA0.B).
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
     ```
 
-    Example 2: Store a FP 32-bit element ZA tile to memory.
+    Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.tile_store %tile, <ver>, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
-    Example 3: Store a 128-bit element ZA tile to memory.
+    Example 3: Store a 128-bit element ZA tile with horizontal layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -314,8 +347,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     }
   }];
 
-  let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
-                       "`:` type($base) `,` type($valueToStore)";
+  let assemblyFormat =
+    "$valueToStore `,` $layout `,` $base `[` $indices `]` attr-dict "
+      "`:` type($base) `,` type($valueToStore)";
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,29 +360,32 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
     slice is defined by the dimension of the 2D scalable vector type pointed by
     the index. A tile slice index describes where in the input tile the tile
-    slice is loaded to. The updated tile is returned as the result.
+    slice is loaded to. A tile slice layout attribute specifies whether the
+    tile slice being loaded at the given index is horizontal or vertical. The
+    updated tile is returned as the result.
 
     The slice of memory read is defined by a base and indices and must be
     contiguous. The memref must be either rank 1 or rank 2, have dynamic
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
-    Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+    Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile_update = arm_sme.load_tile_slice <hor>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+    Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+    Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to load from">:$base,
       SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
   let results = (outs SMETile:$result);
@@ -363,7 +400,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+    $layout `,` $base `[` $indices `]` `,` $tile `,` $tile_slice_index
       attr-dict `:` type($base) `,` type($result)
   }];
 }
@@ -374,29 +411,31 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
     Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
     slice is defined by the dimension of the 2D scalable vector type pointed by
     the index. A tile slice index describes where in the input tile the tile
-    slice is stored from.
+    slice is stored from. A tile slice layout attribute specifies whether the
+    tile slice being stored from the given index is horizontal or vertical.
 
     The slice of memory written is defined by a base and indices and must be
     contiguous. The memref must be either rank 1 or rank 2, have dynamic
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the input tile.
 
-    Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
+    Example 1: Store vector<[16]xi8> horizontal tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
     ```
 
-    Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
+    Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
-    Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+    Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -409,7 +448,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
   }];
 
   let assemblyFormat = [{
-    $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+    $tile `,` $tile_slice_index `,` $layout `,` $base `[` $indices `]`
       attr-dict `:` type($base) `,` type($tile)
   }];
 }
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index d20ee65e62e7dc0..7afd0d014541687 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
 set(LLVM_TARGET_DEFINITIONS ArmSME.td)
 mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+
+mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
+mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
+add_public_tablegen_target(MLIRArmSMEEnumsIncGen)
+add_dependencies(mlir-headers MLIRArmSMEEnumsIncGen)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..86cabe67f2695f1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -54,7 +54,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///
 ///  BEFORE:
 ///  ```mlir
-///  %tile = arm_sme.tile_load %src[%c0, %c0] :
+///  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] :
 ///    memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  ```
 ///
@@ -68,7 +68,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+///    %tile_update = arm_sme.load_tile_slice <hor>, %src[%tile_slice_idx],
 ///      %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
-                                              tileLoadOp.getBase(), tile,
-                                              memrefIndices, tileSliceIndex);
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getLayout(), tileLoadOp.getBase(), tile,
+        memrefIndices, tileSliceIndex);
 
     rewriter.setInsertionPointAfter(forOp);
 
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///
 ///  BEFORE:
 ///  ```mlir
-///  arm_sme.tile_store %tile, %dest[%c0, %c0]
+///  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0]
 ///    : memref<?x?xi32>, vector<[4]x[4]xi32
 ///  ```
 ///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-///    arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
-///      : memref<?x?xi32>, vector<[4]x[4]xi32>
+///    arm_sme.store_tile_slice %tile, %tile_slice_idx, <ver>,
+///      %dest[%tile_slice_idx] : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
         tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        tileStoreOp.getBase(), memrefIndices);
+        tileStoreOp.getLayout(), tileStoreOp.getBase(), memrefIndices);
 
     return success();
   }
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..30c516ffbe1e900 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Casting.h"
 
@@ -65,8 +66,8 @@ namespace {
 ///
 /// is converted to:
 ///
-///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
-///                                                   vector<[16]x[16]xi8>
+///   arm_sme.tile_store %vector, <hor>, %source[%c0, %c0]
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -81,8 +82,8 @@ struct TransferWriteToArmSMELowering
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, writeOp.getVector(), writeOp.getSource(),
-        writeOp.getIndices());
+        writeOp, writeOp.getVector(), arm_sme::TileSliceLayout::Horizontal,
+        writeOp.getSource(), writeOp.getIndices());
     return success();
   }
 };
@@ -97,7 +98,8 @@ struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
-        load, load.getVectorType(), load.getBase(), load.getIndices());
+        load, load.getVectorType(), arm_sme::TileSliceLayout::Horizontal,
+        load.getBase(), load.getIndices());
 
     return success();
   }
@@ -113,7 +115,8 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        store, store.getValueToStore(), store.getBase(), store.getIndices());
+        store, store.getValueToStore(), arm_sme::TileSliceLayout::Horizontal,
+        store.getBase(), store.getIndices());
 
     return success();
   }
@@ -239,11 +242,84 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.transpose.
+///
+/// Stores the input tile to memory and reloads vertically.
+///
+/// Example:
+///
+///   %transposed_src = vector.transpose %src, [1, 0]
+///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
+///   %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///   %transposed_src = arm_sme.tile_load <ver>, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///
+/// NOTE: Tranposing via memory is obviously expensive, the current intention
+/// is to avoid the transpose if possible, this is therefore intended as a
+/// fallback and to provide base support for Vec...
[truncated]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me it would make sense to merge this with test-load-vertical.mlir and have 2 RUN lines with different entry points.

In principle, both end-to-end tests verify matrix transposition. And there's a lot of code duplication. Also, for end-to-end tests, I think that it would be nice to converge towards testing higher-level operations (e.g. matmul or conv).

If you do decide to merge it, you could keep this name to indicate that this is indeed matrix transposition.

Copy link
Collaborator Author

@c-rhodes c-rhodes Sep 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah there is a fair amount of duplication but I'd like to keep a separate test for transpose for consistency with the generic Vector CPU tests, which has a similar test of the same name. The ArmSME Vector CPU tests have diverged slightly, particularly in the naming, I'd like to clean that up at some point.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (% some comments), thanks!

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor nit, but LGTM:

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I had reviewed this yesterday but didn't submit the comments :)

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the comments!

This patch adds support for lowering vector.transpose to ArmSME. It's
implemented by storing the input tile of the tranpose to memory and
reloading vertically, building on top of the tile slice layout support.

Tranposing via memory is obviously expensive, the current intention is
to avoid the transpose if possible, this is therefore intended as a
fallback and to provide base support for Vector ops. If it turns out
transposes can't be avoided then this should be replaced with a more
optimal implementation, perhaps with tile <-> vector (MOVA) ops.

Depends on llvm#66758.
@c-rhodes c-rhodes force-pushed the arm-sme-vector-transpose branch from 7383332 to fed0df9 Compare September 25, 2023 09:38
@c-rhodes c-rhodes merged commit eaf1590 into llvm:main Sep 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants