-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Switch to an attribute-based tile allocation scheme #73253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sme Author: Benjamin Maxwell (MacDue) Changes: This reworks the ArmSME dialect to use attributes for tile allocation. This has a number of advantages and corrects some issues with the previous approach:
As part of this patch we bid farewell to the following operations: arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32 These are now replaced with: // Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32> The new tile allocation works by operations implementing the `ArmSMETileOpInterface`. This interface says that an operation needs to be assigned a tile ID, and may conditionally allocate a new SME tile. Operations allocate a new tile by implementing... std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() ...and returning what type of tile the op allocates (ZAB, ZAH, etc). Operations that don't allocate a tile return `std::nullopt` (which is the default behaviour). Currently the following ops are defined as allocating: arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified) Allocating operations become the roots for the tile allocation pass, which currently just (naively) assigns all transitive uses of a root operation the same tile ID. However, this is enough to handle current use cases. Once tile IDs have been allocated subsequent rewrites can forward the tile IDs to any newly created operations. Patch is 279.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73253.diff 41 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index fe1f9062a37ef51..1da8e488a4c4647 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,8 @@
#define MLIR_DIALECT_ARMSME_IR_ARMSME_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -22,7 +24,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
new file mode 100644
index 000000000000000..430f3571001c8f4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
@@ -0,0 +1,16 @@
+//===- ArmSMEDialect.h - Arm SME Dialect Enums ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_ENUMS_H
+#define MLIR_DIALECT_ARMSME_ENUMS_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#endif
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index b75918ebf2f6d9c..2a0167afa8bae9e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -54,7 +54,10 @@ def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
}];
}
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+class ArmSME_IntrOp<string mnemonic,
+ list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<int> overloadedOperands = [],
list<Trait> traits = [], int numResults = 0,
list<int> overloadedResults = []>
: LLVM_IntrOpBase<
@@ -64,16 +67,26 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
/*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/numResults>;
+ /*int numResults=*/numResults,
+ /*bit requiresAccessGroup=*/0,
+ /*bit requiresAliasAnalysis=*/0,
+ /*bit requiresFastmath=*/0,
+ /*list<int> immArgPositions=*/immArgPositions,
+ /*list<string> immArgAttrNames=*/immArgAttrNames>;
// Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
- Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero",
+ /*immArgPositions=*/[0],
+ /*immArgAttrNames=*/["tile_mask"]>,
+ Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
// MOP's
class ArmSME_IntrMopOverloadedOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic, [4]>,
- Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+ : ArmSME_IntrOp<mnemonic,
+ /*immArgPositions=*/[0],
+ /*immArgAttrNames=*/["tile_id"],
+ /*overloadedOperands=*/[4]>,
+ Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
Arg<MOPVector, "LHS vector operand">:$lhs_vector,
@@ -92,12 +105,17 @@ def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+class ArmSME_IntrLoadStoreOp<string mnemonic>
+ : ArmSME_IntrOp<mnemonic,
+ /*immArgPositions=*/[2],
+ /*immArgAttrNames=*/["tile_id"]>;
+
// Loads
class ArmSME_IntrLoadOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic>,
+ : ArmSME_IntrLoadStoreOp<mnemonic>,
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
Arg<LLVM_AnyPointer, "Load address">:$load_address,
- Arg<I32, "Virtual tile ID">:$tile_id,
+ Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;
def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
@@ -113,10 +131,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
// Stores
class ArmSME_IntrStoreOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic>,
+ : ArmSME_IntrLoadStoreOp<mnemonic>,
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
- Arg<I32, "Virtual tile ID">:$tile_id,
+ Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;
def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
@@ -138,22 +156,28 @@ def LLVM_aarch64_sme_str
// Vector to tile slice
class LLVM_aarch64_sme_write<string direction>
- : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+ : ArmSME_IntrOp<"write." # direction,
+ /*immArgPositions=*/[0],
+ /*immArgAttrNames=*/["tile_id"],
+ /*overloadedOperands=*/[3],
[AllShapesMatch<["predicate", "vector"]>]>,
- Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+ Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index,
Arg<SVEPredicate, "Vector predicate">:$predicate,
Arg<SVEVector, "Vector operand">:$vector)>;
// Tile slice to vector
class LLVM_aarch64_sme_read<string direction>
- : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+ : ArmSME_IntrOp<"read." # direction,
+ /*immArgPositions=*/[2],
+ /*immArgAttrNames=*/["tile_id"],
+ /*overloadedOperands=*/[],
[AllShapesMatch<["vector", "predicate", "res"]>,
AllElementTypesMatch<["vector", "res"]>],
/*numResults=*/1, /*overloadedResults=*/[0]>,
Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
Arg<SVEPredicate, "Vector predicate">:$predicate,
- Arg<I32, "Virtual tile ID">:$tile_id,
+ Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index ba33a2826e6ca4b..abcc9b649c4a530 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -21,6 +21,99 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+//===----------------------------------------------------------------------===//
+// ArmSME op interfaces
+//===----------------------------------------------------------------------===//
+
+def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
+ [
+ I32EnumAttrCase<"ZAB", 0, "za.b">,
+ I32EnumAttrCase<"ZAH", 1, "za.h">,
+ I32EnumAttrCase<"ZAS", 2, "za.s">,
+ I32EnumAttrCase<"ZAD", 3, "za.d">,
+ I32EnumAttrCase<"ZAQ", 4, "za.q">,
+ ]>{
+ let cppNamespace = "mlir::arm_sme";
+ let genSpecializedAttr = 0;
+}
+
+def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
+ let description = [{
+ An interface for operations that use or allocate Arm SME tiles. These
+ operations need to be assigned a tile ID an i32 attribute, which specifies
+ which virtual tile within the ZA storage to use. The number of tiles
+ available depends on the type of the tile. This is summarized below:
+
+ | Tile Vector Types | Possible Tile IDs |
+ |-------------------------------------------------------------------------|---------------------|
+ | `vector<[16]x[16]xi8>` | 0 |
+ | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` | 0 and 1 |
+ | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
+ | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
+ | `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
+
+ Operations that allocate a new tiles (such as arm_sme.get_tile), are used as
+ the roots for tile allocation, with all operations that (transitively)
+ depend on a root being assigned the same tile ID.
+ }];
+ let methods = [
+ InterfaceMethod<
+ "Sets the tile ID for this operation.",
+ /*returnType=*/"void",
+ /*methodName=*/"setTileId",
+ /*arguments=*/(ins "mlir::IntegerAttr":$tileId),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/ [{
+ if (!tileId)
+ return;
+ ::mlir::Operation* op = this->getOperation();
+ op->setAttr("tile_id", tileId);
+ }]
+ >,
+ InterfaceMethod<
+ "Returns the (possibly null) tile ID assigned to this operation.",
+ /*returnType=*/"mlir::IntegerAttr",
+ /*methodName=*/"getTileId",
+ /*arguments=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/ [{
+ ::mlir::Operation* op = this->getOperation();
+ return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
+ }]
+ >,
+ InterfaceMethod<
+ "The type of tile this operation allocates (or none)",
+ /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
+ /*methodName=*/"getAllocatedTileType",
+ /*arguments=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/ [{
+ // Do not allocate a new tile.
+ return std::nullopt;
+ }]
+ >
+ ];
+
+ let extraSharedClassDeclaration = [{
+ // A helper to create a new operation and propagate this operations tile ID.
+ template<typename T, typename... Args>
+ T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
+ auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
+ if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
+ tileOp.setTileId($_op.getTileId());
+ return op;
+ }
+
+ // A helper to replace this operation and forward any tile ID.
+ template<typename T, typename... Args>
+ T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
+ auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
+ rewriter.replaceOp($_op, newOp);
+ return newOp;
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ArmSME type definitions
//===----------------------------------------------------------------------===//
@@ -44,7 +137,8 @@ def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
- "a vector type that fits into a SME tile">
+ "a vector type that fits into a SME tile",
+ "VectorType">
{
let description = [{
Possible vector types:
@@ -66,40 +160,6 @@ def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
}];
}
-def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
- "an identifier of a virtual tile (of a size) within the ZA storage">
-{
- let description = [{
- The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
- the integer indicates the tile to use, and the bit size indicates the size
- of tile. The number of tiles available and the element types of those depend
- on the size. This is summarised below:
-
- | Tile ID Type | Possible Tile IDs | Tile Vector Types |
- |--------------|---------------------|-------------------------------------------------------------------------|
- | `i8` | 0 | `vector<[16]x[16]xi8>` |
- | `i16` | 0 and 1 | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
- | `i32` | 0 to 3 (inclusive) | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` |
- | `i64` | 0 to 7 (inclusive) | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` |
- | `i128` | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>` |
- }];
-}
-
-// A type constraint that verifies the bitwidth of the scalar integer returned
-// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
-def TileElementWidthMatchesTileID : TypesMatchWith<
- "`tile_id` has the same number of bits as elements in `vector`",
- "vector", "tile_id",
- "IntegerType::get("
- "$_self.getContext(),"
- "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
- "? ::llvm::cast<IntegerType>("
- "::llvm::cast<VectorType>($_self).getElementType())"
- ".getWidth()"
- ": ::llvm::cast<FloatType>("
- "::llvm::cast<VectorType>($_self).getElementType())"
- ".getWidth())">;
-
class HasMatchingMaskTypeConstraint<string vector, string mask> :
OptionalTypesMatchWith<
mask # " has i1 element type and same shape as " # vector,
@@ -162,125 +222,67 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
Op<ArmSME_Dialect, mnemonic, traits> {}
-def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
- let summary = "Cast from tile id to 2-d scalable vector type";
+def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
+ let summary = "Returns a SME virtual tile";
let description = [{
- A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
- scalable vector type, which represents an SME "virtual tile". This would
- normally be used when lowering operations that return "virtual tile" vector
- types to model the output. This is required to preserve dataflow as SME
- intrinsics have no return values.
+ Allocates a new SME "virtual tile" within a function. The contents of the
+ tile returned from this operation undefined.
- Example:
+ Example 1:
- Input:
```mlir
- %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
- vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ // Allocate an 8-bit element "virtual tile"
+ %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
```
- After lowering `vector.load`:
+ Example 2:
+
```mlir
- %tile_id = arm_sme.get_tile_id : i32
- scf.for %vnum = %c0 to %num_vectors step %c1 {
- // ...
- "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
- }
- %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
- vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ // Allocate two 16-bit element "virtual tiles"
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+ %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
```
- In the example above, the `vector.load` can't be replaced with an SME
- intrinsic that has no outputs since it is used by the `vector.store`.
- However, by inserting a `cast_tile_to_vector` op after the load intrinsics
- the `vector.load` can be replaced. This enables "local" rewrites on
- individual vector ops, rather than "global" rewrites that would have to
- look at the vector op uses and also lower them.
-
- Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
- the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
- }];
- let arguments = (ins TileID:$tile_id);
- let results = (outs SMETile:$vector);
- let assemblyFormat =
- "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
- let hasCanonicalizeMethod = 1;
-}
-
-def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
- let summary = "Cast from 2-d scalable vector type to tile id";
- let description = [{
- A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
- type, which represents an SME "virtual tile", to a tile id. This is
- required to preserve dataflow as the SME intrinsics have no return values.
-
- Example:
-
- Input:
+ Example 3:
```mlir
- %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
- vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ // Allocate an 128-bit element "virtual tile"
+ %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
```
+ }];
- After lowering `vector.store`:
- ```mlir
- %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
- scf.for %vnum = %c0 to %num_vectors step %c1 {
- // ...
- %tile_id = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
- "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+ let results = (outs SMETile:$tile);
+ let assemblyFormat = "attr-dict `:` type($tile)";
+
+ let extraClassDeclaration = [{
+ VectorType getTileType() {
+ return ::llvm::cast<VectorType>(getTile().getType());
}
- ```
- Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
- the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ return arm_sme::getSMETileType(getTileType());
+ }
}];
- let arguments = (ins SMETile:$vector);
- let results = (outs TileID:$tile_id);
- let assemblyFormat =
- "$vector attr-dict `:` type($vector) `to` type($tile_id)";
- let hasCanonicalizeMethod = 1;
}
-def GetTileID : ArmSME_Op<"get_tile_id"> {
- let summary = "Returns an SME \"virtual tile\" id";
+def MaterializeSSATile : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
+ let summary = "SME tile placeholder";
let description = [{
- A `get_tile_id` operation returns a scalar integer representing an SME
- "virtual tile" id. The bitwidth of the scalar indicates the element
- bitwidth of the "virtual tile".
-
- The scope of a tile id is a function and cannot be passed or returned from
- functions.
+ A placeholder to preserve dataflow while lowering to SME intrinsics (which
+ do not take or return tile values). This operation is intended to be DCE'd
+ once all ArmSME operations have been lowered.
- Example:
- ```mlir
- // Allocate and return an 8-bit element "virtual tile" id
- %za0_b = arm_sme.get_tile_id : i8
- ```
-
- Example:
- ```
- // Allocate and return two 16-bit element "virtual tile" ids
- %za0_h = arm_sme.get_tile_id : i16
- %za1_h = arm_sme.get_tile_id : i16
- ```
-
- Example:
- ```
- // Allocate and return an 128-bit element "virtual tile" id
- %za0_q = arm_sme.get_tile_id : i128
- ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks! There's quite a bit of churn, but sadly that's unavoidable.
I've left a few minor comments - nothing major. I'd like to take another look though - this is quite dense 😅 .
Allocates a new SME "virtual tile" within a function. The contents of the | ||
tile returned from this operation undefined. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why "within a function"?
Also:
Allocates a new SME "virtual tile" within a function. The contents of the | |
tile returned from this operation undefined. | |
Allocates a new SME "virtual tile" within a function. The contents of the | |
tile returned from this operation are undefined. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tiles are only allocated with respect to a function right now, but they don't necessarily need to be, MLIR is pretty high-level in SCF, so tiles could be allocated in scopes (like in C).. which could be a something that's useful.
6bbd4d7
to
e3b0e3d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work Ben, I've left a few minor questions / comments but otherwise LGTM, really polished
auto tileMask = rewriter.create<arith::ShLIOp>( | ||
loc, baseMask, castTileIDToI32(tileId, loc, rewriter)); | ||
rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask); | ||
int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
c0fba6f
to
77f414a
Compare
Since llvm#73253 we now support loops in SSA form for tiles (i.e. loops that take `iter_args` and yield a new tile), so this patch updates lowerings to use that. This is a NFC, as it still lowers to the same intrinsics, but this makes IR less 'surprising' at a higher-level, and may be recognised by more transforms.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly nits. I'd like to take another look at tile allocation - just to better understand the limitations. The overall design makes a lot of sense.
This reworks the ArmSME dialect to use attributes for tile allocation. This has a number of advantages and corrects some issues with the previous approach: * Tile allocation can now be done ASAP (i.e. immediately after `-convert-vector-to-arm-sme`) * SSA form for control flow is now supported (e.g.`scf.for` loops that yield tiles) * ArmSME ops can be converted to intrinsics very late (i.e. after lowering to control flow) * Tests are simplified by removing constants and casts * Avoids correctness issues with representing LLVM `immargs` as MLIR values - The tile ID on the SME intrinsics is an `immarg` (so is required to be a compile-time constant), `immargs` should be mapped to MLIR attributes (this is already the case for intrinsics in the LLVM dialect) - Using MLIR values for `immargs` can lead to invalid LLVM IR being generated (and passes such as -cse making incorrect optimizations) As part of this patch we bid farewell to the following operations: ```mlir arm_sme.get_tile_id : i32 arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32> arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32 ``` These are now replaced with: ```mlir // Allocates a new tile with (indeterminate) state: arm_sme.get_tile : vector<[4]x[4]xi32> // A placeholder operation for lowering ArmSME ops to intrinsics: arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32> ``` The new tile allocation works by operations implementing the `ArmSMETileOpInterface`. This interface says that an operation needs to be assigned a tile ID, and may conditionally allocate a new SME tile. Operations allocate a new tile by implementing... ```c++ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() ``` ...and returning what type of tile the op allocates (ZAB, ZAH, etc). Operations that don't allocate a tile return `std::nullopt` (which is the default behaviour). 
Currently the following ops are defined as allocating: ```mlir arm_sme.get_tile arm_sme.zero arm_sme.tile_load arm_sme.outerproduct // (if no accumulator is specified) ``` Allocating operations become the roots for the tile allocation pass, which currently just (naively) assigns all transitive uses of a root operation the same tile ID. However, this is enough to handle current use cases. Once tile IDs have been allocated subsequent rewrites can forward the tile IDs to any newly operations.
This function follows uses of a value through control flow. It understands basic SCF contructs and more generally works on control flow branches. (the previous slice analysis is very basic and does not understand any control flow)
784df8a
to
63f3b0f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes a lot of sense, LGTM :)
It's been a huge effort - thank you for working on this! This change is a significant improvement the SME support in MLIR 🙏🏻 🙇🏻 .
(please wait for +1 from @joker-eph before landing this)
I didn't look deeper than the high-level drive-by comments I made before, so please don't wait for me :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! This is a big improvements! Thanks!
I have just one design comment.
bitwidth of the "virtual tile". | ||
A placeholder to preserve dataflow while lowering to SME intrinsics (which | ||
do not take or return SME virtual tile values). This operation is intended | ||
to be DCE'd once all ArmSME operations have been lowered. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This sounds like we are using the IR to pass information from between patterns of the conversion (i.e., we are to some extent materializing a state of a transformation in the IR). Did you think about create an actual C++ state and keep that information there instead of materializing it during the IR? Am I missing something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not used to pass state, it's just used to replace arm_sme ops (that take and return tile values), with arm_sme intrinsics that only take a tile_id. The point is after all arm_sme ops have been lowered to intrinsics the arm_sme.materialize_ssa_tile
ops (which don't do anything other than act as a placeholder to allow incremental rewrites, like unrealized conversion casts), become dead code and fold away.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This idea is shown in: https://github.com/MacDue/llvm-project/blob/b9064adf0d6a53c61a907fc2e61f51997a651eab/mlir/test/Dialect/ArmSME/canonicalize.mlir (where the arm_sme.materialize_ssa_tile
op and the unused tile SSA values fold way).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Better naming suggestions here are welcome :))
Since llvm#73253 we now support loops in SSA form for tiles (i.e. loops that take `iter_args` and yield a new tile), so this patch updates lowerings to use that. This is a NFC, as it still lowers to the same intrinsics, but this makes IR less 'surprising' at a higher-level, and may be recognised by more transforms.
…gs (#73922) Since #73253, loops over tiles in SSA form (i.e. loops that take `iter_args` and yield a new tile) are supported, so this patch updates ArmSME lowerings to this form. This is a NFC, as it still lowers to the same intrinsics, but this makes IR less 'surprising' at a higher-level, and may be recognised by more transforms. Example: IR before: ```mlir scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 { arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> } // ... later use %tile ``` IR now: ```mlir %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) { %tile_update = arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %iter_tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> scf.yield %tile_update : vector<[4]x[4]xi32> } // ... later use %broadcast_to_tile ```
This reworks the ArmSME dialect to use attributes for tile allocation. This has a number of advantages and corrects some issues with the previous approach:
-convert-vector-to-arm-sme
)scf.for
loops that yield tiles)immargs
as MLIR valuesimmarg
(so is required to be a compile-time constant),immargs
should be mapped to MLIR attributes (this is already the case for intrinsics in the LLVM dialect)immargs
can lead to invalid LLVM IR being generated (and passes such as -cse making incorrect optimizations)As part of this patch we bid farewell to the following operations:
These are now replaced with:
The new tile allocation works by operations implementing the
ArmSMETileOpInterface
. This interface says that an operation needs to be assigned a tile ID, and may conditionally allocate a new SME tile.Operations allocate a new tile by implementing...
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
...and returning what type of tile the op allocates (ZAB, ZAH, etc).
Operations that don't allocate a tile return
std::nullopt
(which is the default behaviour).Currently the following ops are defined as allocating:
Allocating operations become the roots for the tile allocation pass, which currently just (naively) assigns all transitive uses of a root operation the same tile ID. However, this is enough to handle current use cases.
Once tile IDs have been allocated subsequent rewrites can forward the tile IDs to any newly created operations.