[mlir][ArmSME] Switch to an attribute-based tile allocation scheme #73253

Merged
MacDue merged 11 commits into llvm:main from MacDue:arm_sme_attr_allocation on Nov 30, 2023

Conversation

@MacDue (Member) commented Nov 23, 2023

This reworks the ArmSME dialect to use attributes for tile allocation. This has a number of advantages and corrects some issues with the previous approach:

  • Tile allocation can now be done ASAP (i.e. immediately after -convert-vector-to-arm-sme)
  • SSA form for control flow is now supported (e.g. `scf.for` loops that yield tiles)
  • ArmSME ops can be converted to intrinsics very late (i.e. after lowering to control flow)
  • Tests are simplified by removing constants and casts
  • Avoids correctness issues with representing LLVM immargs as MLIR values
    • The tile ID on the SME intrinsics is an immarg (so is required to be a compile-time constant); immargs should be mapped to MLIR attributes (this is already the case for intrinsics in the LLVM dialect)
    • Using MLIR values for immargs can lead to invalid LLVM IR being generated (and to passes such as -cse making incorrect optimizations), as sketched below
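To sketch that last point (the generic-op syntax and mask value here are illustrative): the mask operand of `llvm.aarch64.sme.zero` is an `immarg`, so it must be a compile-time constant, which an MLIR value cannot guarantee:

```mlir
// Before: the mask is an MLIR value; nothing stops %mask from being a
// runtime value, which cannot be translated to a valid immarg:
"arm_sme.intr.zero"(%mask) : (i32) -> ()

// After: the mask is an attribute, so it is a constant by construction:
"arm_sme.intr.zero"() {tile_mask = 17 : i32} : () -> ()
```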

As part of this patch we bid farewell to the following operations:

```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```

These are now replaced with:

```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```

The new tile allocation works by operations implementing the `ArmSMETileOpInterface`. This interface says that an operation needs to be assigned a tile ID, and may conditionally allocate a new SME tile.

Operations allocate a new tile by implementing...

```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```

...and returning what type of tile the op allocates (ZAB, ZAH, etc).

Operations that don't allocate a tile return `std::nullopt` (which is the default behaviour).
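For example, `arm_sme.get_tile` implements it as follows (taken from this patch's op definition; `arm_sme::getSMETileType` maps a tile vector type to its tile type):

```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
  // Derive the tile type from the result vector type,
  // e.g. vector<[4]x[4]xi32> -> ZAS.
  return arm_sme::getSMETileType(getTileType());
}
```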

Currently the following ops are defined as allocating:

```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```

Allocating operations become the roots for the tile allocation pass, which currently just (naively) assigns all transitive uses of a root operation the same tile ID. However, this is enough to handle current use cases.
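For example (a sketch; the `tile_id` attribute is real, but the op mix and the assigned ID are illustrative):

```mlir
// arm_sme.zero allocates a tile, so it is a root; the store that uses
// the tile is assigned the same tile ID:
%tile = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %dest[%c0, %c0] {tile_id = 0 : i32}
  : memref<?x?xi32>, vector<[4]x[4]xi32>
```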

Once tile IDs have been allocated, subsequent rewrites can forward them to any newly created operations.
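The interface provides helpers for this; a usage sketch, where `tileOp`, `rewriter`, `loc`, and `tileType` are assumed from a surrounding rewrite pattern:

```c++
// Create a new op at `loc` and copy over tileOp's tile_id attribute (if
// set). replaceWithAndForwardTileId similarly replaces the op outright
// while forwarding its tile ID.
auto zero =
    tileOp.createOpAndForwardTileId<arm_sme::ZeroOp>(rewriter, loc, tileType);
```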

@llvmbot (Member) commented Nov 23, 2023

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Patch is 279.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73253.diff

41 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+3-1)
  • (added) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h (+16)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+38-14)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+152-131)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt (+6)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+6-10)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+99-98)
  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+19-28)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+3-20)
  • (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+2-20)
  • (modified) mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp (+60-45)
  • (modified) mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt (-2)
  • (modified) mlir/lib/Dialect/ArmSME/Utils/Utils.cpp (+24-20)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+4-6)
  • (modified) mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir (+2-4)
  • (removed) mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir (-51)
  • (modified) mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir (+147-105)
  • (modified) mlir/test/Dialect/ArmSME/canonicalize.mlir (+19-21)
  • (modified) mlir/test/Dialect/ArmSME/cse.mlir (+25-11)
  • (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+11-51)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+37-156)
  • (modified) mlir/test/Dialect/ArmSME/tile-allocation.mlir (+185-191)
  • (modified) mlir/test/Dialect/ArmSME/tile-zero-masks.mlir (+16-86)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+200-210)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+2-4)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir (+2-3)
  • (modified) mlir/test/Target/LLVMIR/arm-sme-invalid.mlir (+6-9)
  • (modified) mlir/test/Target/LLVMIR/arm-sme.mlir (+161-170)
  • (modified) mlir/tools/mlir-query/mlir-query.cpp (+3-3)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index fe1f9062a37ef51..1da8e488a4c4647 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,8 @@
 #define MLIR_DIALECT_ARMSME_IR_ARMSME_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -22,7 +24,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
new file mode 100644
index 000000000000000..430f3571001c8f4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
@@ -0,0 +1,16 @@
+//===- ArmSMEDialect.h - Arm SME Dialect Enums ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_ENUMS_H
+#define MLIR_DIALECT_ARMSME_ENUMS_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#endif
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index b75918ebf2f6d9c..2a0167afa8bae9e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -54,7 +54,10 @@ def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
   }];
 }
 
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+class ArmSME_IntrOp<string mnemonic,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = [],
+                    list<int> overloadedOperands = [],
                     list<Trait> traits = [], int numResults = 0,
                     list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
@@ -64,16 +67,26 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
           /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/numResults>;
+          /*int numResults=*/numResults,
+          /*bit requiresAccessGroup=*/0,
+          /*bit requiresAliasAnalysis=*/0,
+          /*bit requiresFastmath=*/0,
+          /*list<int> immArgPositions=*/immArgPositions,
+          /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 // Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
-                            Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero",
+                            /*immArgPositions=*/[0],
+                            /*immArgAttrNames=*/["tile_mask"]>,
+                            Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
 
 // MOP's
 class ArmSME_IntrMopOverloadedOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic, [4]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+    : ArmSME_IntrOp<mnemonic,
+      /*immArgPositions=*/[0],
+      /*immArgAttrNames=*/["tile_id"],
+      /*overloadedOperands=*/[4]>,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
                  Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
                  Arg<MOPVector, "LHS vector operand">:$lhs_vector,
@@ -92,12 +105,17 @@ def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
 def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
 def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
 
+class ArmSME_IntrLoadStoreOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic,
+      /*immArgPositions=*/[2],
+      /*immArgAttrNames=*/["tile_id"]>;
+
 // Loads
 class ArmSME_IntrLoadOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Load address">:$load_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
@@ -113,10 +131,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
 
 // Stores
 class ArmSME_IntrStoreOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
@@ -138,22 +156,28 @@ def LLVM_aarch64_sme_str
 
 // Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
-    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+    : ArmSME_IntrOp<"write." # direction,
+                    /*immArgPositions=*/[0],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[3],
                     [AllShapesMatch<["predicate", "vector"]>]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
 // Tile slice to vector
 class LLVM_aarch64_sme_read<string direction>
-    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+    : ArmSME_IntrOp<"read." # direction,
+                    /*immArgPositions=*/[2],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[],
                     [AllShapesMatch<["vector", "predicate", "res"]>,
                      AllElementTypesMatch<["vector", "res"]>],
                     /*numResults=*/1, /*overloadedResults=*/[0]>,
       Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
-                     Arg<I32, "Virtual tile ID">:$tile_id,
+                     Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index ba33a2826e6ca4b..abcc9b649c4a530 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -21,6 +21,99 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
+//===----------------------------------------------------------------------===//
+// ArmSME op interfaces
+//===----------------------------------------------------------------------===//
+
+def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
+    [
+      I32EnumAttrCase<"ZAB", 0, "za.b">,
+      I32EnumAttrCase<"ZAH", 1, "za.h">,
+      I32EnumAttrCase<"ZAS", 2, "za.s">,
+      I32EnumAttrCase<"ZAD", 3, "za.d">,
+      I32EnumAttrCase<"ZAQ", 4, "za.q">,
+    ]>{
+  let cppNamespace = "mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
+  let description = [{
+    An interface for operations that use or allocate Arm SME tiles. These
+    operations need to be assigned a tile ID an i32 attribute, which specifies
+    which virtual tile within the ZA storage to use. The number of tiles
+    available depends on the type of the tile. This is summarized below:
+
+    | Tile Vector Types                                                       | Possible Tile IDs   |
+    |-------------------------------------------------------------------------|---------------------|
+    | `vector<[16]x[16]xi8>`                                                  | 0                   |
+    | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` | 0 and 1             |
+    | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
+    | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
+    | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
+
+    Operations that allocate a new tiles (such as arm_sme.get_tile), are used as
+    the roots for tile allocation, with all operations that (transitively)
+    depend on a root being assigned the same tile ID.
+  }];
+  let methods = [
+    InterfaceMethod<
+      "Sets the tile ID for this operation.",
+      /*returnType=*/"void",
+      /*methodName=*/"setTileId",
+      /*arguments=*/(ins "mlir::IntegerAttr":$tileId),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        if (!tileId)
+          return;
+        ::mlir::Operation* op = this->getOperation();
+        op->setAttr("tile_id", tileId);
+      }]
+    >,
+    InterfaceMethod<
+      "Returns the (possibly null) tile ID assigned to this operation.",
+      /*returnType=*/"mlir::IntegerAttr",
+      /*methodName=*/"getTileId",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        ::mlir::Operation* op = this->getOperation();
+        return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
+      }]
+    >,
+    InterfaceMethod<
+      "The type of tile this operation allocates (or none)",
+      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
+      /*methodName=*/"getAllocatedTileType",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        // Do not allocate a new tile.
+        return std::nullopt;
+      }]
+    >
+  ];
+
+  let extraSharedClassDeclaration = [{
+    // A helper to create a new operation and propagate this operations tile ID.
+    template<typename T, typename... Args>
+    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
+      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
+      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
+        tileOp.setTileId($_op.getTileId());
+      return op;
+    }
+
+    // A helper to replace this operation and forward any tile ID.
+    template<typename T, typename... Args>
+    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
+      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
+      rewriter.replaceOp($_op, newOp);
+      return newOp;
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
@@ -44,7 +137,8 @@ def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
-                        "a vector type that fits into a SME tile">
+                        "a vector type that fits into a SME tile",
+                        "VectorType">
 {
   let description = [{
     Possible vector types:
@@ -66,40 +160,6 @@ def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
   }];
 }
 
-def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
-                       "an identifier of a virtual tile (of a size) within the ZA storage">
-{
-  let description = [{
-    The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
-    the integer indicates the tile to use, and the bit size indicates the size
-    of tile. The number of tiles available and the element types of those depend
-    on the size. This is summarised below:
-
-    | Tile ID Type | Possible Tile IDs   | Tile Vector Types                                                       |
-    |--------------|---------------------|-------------------------------------------------------------------------|
-    | `i8`         | 0                   | `vector<[16]x[16]xi8>`                                                  |
-    | `i16`        | 0 and 1             | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
-    | `i32`        | 0 to 3 (inclusive)  | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          |
-    | `i64`        | 0 to 7 (inclusive)  | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          |
-    | `i128`       | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>`                                                  |
-  }];
-}
-
-// A type constraint that verifies the bitwidth of the scalar integer returned
-// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
-def TileElementWidthMatchesTileID : TypesMatchWith<
-  "`tile_id` has the same number of bits as elements in `vector`",
-  "vector", "tile_id",
-  "IntegerType::get("
-      "$_self.getContext(),"
-      "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
-          "? ::llvm::cast<IntegerType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth()"
-          ": ::llvm::cast<FloatType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth())">;
-
 class HasMatchingMaskTypeConstraint<string vector, string mask> :
   OptionalTypesMatchWith<
     mask # " has i1 element type and same shape as " # vector,
@@ -162,125 +222,67 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from tile id to 2-d scalable vector type";
+def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
+  let summary = "Returns a SME virtual tile";
   let description = [{
-    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
-    scalable vector type, which represents an SME "virtual tile". This would
-    normally be used when lowering operations that return "virtual tile" vector
-    types to model the output. This is required to preserve dataflow as SME
-    intrinsics have no return values.
+    Allocates a new SME "virtual tile" within a function. The contents of the
+    tile returned from this operation undefined.
 
-    Example:
+    Example 1:
 
-    Input:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate an 8-bit element "virtual tile"
+    %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
-    After lowering `vector.load`:
+    Example 2:
+
     ```mlir
-    %tile_id = arm_sme.get_tile_id : i32
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-    }
-    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate two 16-bit element "virtual tiles"
+    %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+    %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
-    In the example above, the `vector.load` can't be replaced with an SME
-    intrinsic that has no outputs since it is used by the `vector.store`.
-    However, by inserting a `cast_tile_to_vector` op after the load intrinsics
-    the `vector.load` can be replaced. This enables "local" rewrites on
-    individual vector ops, rather than "global" rewrites that would have to
-    look at the vector op uses and also lower them.
-
-    Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
-    the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
-  }];
-  let arguments = (ins TileID:$tile_id);
-  let results = (outs SMETile:$vector);
-  let assemblyFormat =
-    "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
-  let hasCanonicalizeMethod = 1;
-}
-
-def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from 2-d scalable vector type to tile id";
-  let description = [{
-    A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
-    type, which represents an SME "virtual tile", to a tile id. This is
-    required to preserve dataflow as the SME intrinsics have no return values.
-
-    Example:
-
-    Input:
+    Example 3:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate an 128-bit element "virtual tile"
+    %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
+  }];
 
-    After lowering `vector.store`:
-    ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      %tile_id = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
-      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  let results = (outs SMETile:$tile);
+  let assemblyFormat = "attr-dict `:` type($tile)";
+
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getTile().getType());
     }
-    ```
 
-    Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
-    the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getTileType());
+    }
   }];
-  let arguments = (ins SMETile:$vector);
-  let results = (outs TileID:$tile_id);
-  let assemblyFormat =
-    "$vector attr-dict `:` type($vector) `to` type($tile_id)";
-  let hasCanonicalizeMethod = 1;
 }
 
-def GetTileID : ArmSME_Op<"get_tile_id"> {
-  let summary = "Returns an SME \"virtual tile\" id";
+def MaterializeSSATile : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
+  let summary = "SME tile placeholder";
   let description = [{
-    A `get_tile_id` operation returns a scalar integer representing an SME
-    "virtual tile" id. The bitwidth of the scalar indicates the element
-    bitwidth of the "virtual tile".
-
-    The scope of a tile id is a function and cannot be passed or returned from
-    functions.
+    A placeholder to preserve dataflow while lowering to SME intrinsics (which
+    do not take or return tile values). This operation is intended to be DCE'd
+    once all ArmSME operations have been lowered.
 
-    Example:
-    ```mlir
-    // Allocate and return an 8-bit element "virtual tile" id
-    %za0_b = arm_sme.get_tile_id : i8
-    ```
-
-    Example:
-    ```
-    // Allocate and return two 16-bit element "virtual tile" ids
-    %za0_h = arm_sme.get_tile_id : i16
-    %za1_h = arm_sme.get_tile_id : i16
-    ```
-
-    Example:
-    ```
-    // Allocate and return an 128-bit element "virtual tile" id
-    %za0_q = arm_sme.get_tile_id : i128
-  ...
[truncated]

@banach-space (Contributor) left a comment

Looks great, thanks! There's quite a bit of churn, but sadly that's unavoidable.

I've left a few minor comments - nothing major. I'd like to take another look though - this is quite dense 😅.

Comment on lines 228 to 229
Allocates a new SME "virtual tile" within a function. The contents of the
tile returned from this operation undefined.
Contributor

Why "within a function"?

Also:

Suggested change
Allocates a new SME "virtual tile" within a function. The contents of the
tile returned from this operation undefined.
Allocates a new SME "virtual tile" within a function. The contents of the
tile returned from this operation are undefined.

Member Author

Tiles are only allocated with respect to a function right now, but they don't necessarily need to be. MLIR is pretty high-level in SCF, so tiles could be allocated in scopes (like in C), which could be something that's useful.

@MacDue MacDue force-pushed the arm_sme_attr_allocation branch from 6bbd4d7 to e3b0e3d Compare November 24, 2023 11:46
@c-rhodes (Collaborator) left a comment

Great work Ben, I've left a few minor questions / comments but otherwise LGTM, really polished

auto tileMask = rewriter.create<arith::ShLIOp>(
loc, baseMask, castTileIDToI32(tileId, loc, rewriter));
rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask);
int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
Collaborator

nice!

@c-rhodes (Collaborator) left a comment

LGTM cheers

@MacDue MacDue force-pushed the arm_sme_attr_allocation branch from c0fba6f to 77f414a Compare November 27, 2023 17:09
MacDue added a commit to MacDue/llvm-project that referenced this pull request Nov 27, 2023
Since llvm#73253 we now have loops in SSA form for tiles (i.e. loops that take
`iter_args` and yield a new tile), so this patch updates lowerings to
use that. This is an NFC, as it still lowers to the same intrinsics,
but this makes the IR less 'surprising' at a higher level, and may be
recognised by more transforms.
@banach-space (Contributor) left a comment

Mostly nits. I'd like to take another look at tile allocation - just to better understand the limitations. The overall design makes a lot of sense.

This function follows uses of a value through control flow. It understands
basic SCF constructs and more generally works on control flow branches.

(the previous slice analysis is very basic and does not understand any
control flow)
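For reference, a hedged sketch of the naive walk being replaced (not this patch's actual implementation, which additionally follows values through control-flow branches):

```c++
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SetVector.h"

// Naive transitive-use walk: collects every op reachable through result
// values. It does not understand control flow, e.g. that a tile yielded
// by scf.yield flows to the parent scf.for's results and iter_args.
static void walkTransitiveUses(mlir::Value value,
                               llvm::SetVector<mlir::Operation *> &deps) {
  for (mlir::Operation *user : value.getUsers()) {
    if (!deps.insert(user))
      continue; // Already visited this op.
    for (mlir::Value result : user->getResults())
      walkTransitiveUses(result, deps);
  }
}
```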
@MacDue MacDue force-pushed the arm_sme_attr_allocation branch from 784df8a to 63f3b0f Compare November 29, 2023 13:01
@banach-space (Contributor) left a comment

Makes a lot of sense, LGTM :)

It's been a huge effort - thank you for working on this! This change is a significant improvement to the SME support in MLIR 🙏🏻 🙇🏻.

(please wait for +1 from @joker-eph before landing this)

@joker-eph (Collaborator)

(please wait for +1 from @joker-eph before landing this)

I didn't look deeper than the high-level drive-by comments I made before, so please don't wait for me :)

@MacDue MacDue merged commit eaff02f into llvm:main Nov 30, 2023
@MacDue MacDue deleted the arm_sme_attr_allocation branch November 30, 2023 10:22
@dcaballe (Contributor) left a comment

Awesome! This is a big improvement! Thanks!
I have just one design comment.

A placeholder to preserve dataflow while lowering to SME intrinsics (which
do not take or return SME virtual tile values). This operation is intended
to be DCE'd once all ArmSME operations have been lowered.
Contributor

This sounds like we are using the IR to pass information between patterns of the conversion (i.e., we are to some extent materializing the state of a transformation in the IR). Did you think about creating an actual C++ state and keeping that information there instead of materializing it in the IR? Am I missing something?

@MacDue (Member Author) commented Nov 30, 2023

This is not used to pass state; it's just used to replace arm_sme ops (that take and return tile values) with arm_sme intrinsics that only take a tile_id. The point is that after all arm_sme ops have been lowered to intrinsics, the arm_sme.materialize_ssa_tile ops (which don't do anything other than act as a placeholder to allow incremental rewrites, like unrealized conversion casts) become dead code and fold away.
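A sketch of the lowering idea (op names from this patch; the intrinsic syntax and mask value are illustrative):

```mlir
// Lowering `arm_sme.zero {tile_id = 0 : i32}`: the intrinsic has no
// results, so the placeholder stands in for the tile value while other
// users of the tile are still being rewritten:
"arm_sme.intr.zero"() {tile_mask = 17 : i32} : () -> ()
%tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
// Not-yet-lowered ArmSME ops keep using %tile; once they have all been
// lowered to intrinsics, %tile is dead and the Pure placeholder folds
// away.
```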

Member Author

This idea is shown in: https://github.com/MacDue/llvm-project/blob/b9064adf0d6a53c61a907fc2e61f51997a651eab/mlir/test/Dialect/ArmSME/canonicalize.mlir (where the arm_sme.materialize_ssa_tile op and the unused tile SSA values fold away).

Member Author

(Better naming suggestions here are welcome :))

MacDue added a commit to MacDue/llvm-project that referenced this pull request Dec 6, 2023
Since llvm#73253 we now have loops in SSA form for tiles (i.e. loops that take
`iter_args` and yield a new tile), so this patch updates lowerings to
use that. This is an NFC, as it still lowers to the same intrinsics,
but this makes the IR less 'surprising' at a higher level, and may be
recognised by more transforms.
MacDue added a commit that referenced this pull request Dec 6, 2023
…gs (#73922)

Since #73253, loops over tiles in SSA form (i.e. loops that take
`iter_args` and yield a new tile) are supported, so this patch updates
ArmSME lowerings to this form. This is an NFC, as it still lowers to the
same intrinsics, but this makes IR less 'surprising' at a higher-level,
and may be recognised by more transforms.


Example:

IR before:
```mlir
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
   arm_sme.move_vector_to_tile_slice 
     %broadcast_to_1d, %tile, %tile_slice_index : 
     vector<[4]xi32> into vector<[4]x[4]xi32>
}
// ... later use %tile
```
IR now:
```mlir
%broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
    step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
{
   %tile_update = arm_sme.move_vector_to_tile_slice
      %broadcast_to_1d, %iter_tile, %tile_slice_index :
      vector<[4]xi32> into vector<[4]x[4]xi32>
  scf.yield %tile_update : vector<[4]x[4]xi32>
}
// ... later use %broadcast_to_tile
```