-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Add comments in tile-spills-and-fills.mlir #91450
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Add comments in tile-spills-and-fills.mlir #91450
Conversation
* adds comments in tile-spills-and-fills.mlir
* adds comments in ArmSMEIntrinsicOps.td
* updates the test in tile-spills-and-fills.mlir so that it does not return 2D scalable vectors (e.g. vector<[4]x[4]xf32>) - that's not supported and not needed for that test
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/91450.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index f051e03efbcda..0e38325f9891a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -115,7 +115,7 @@ class ArmSME_IntrLoadStoreOp<string mnemonic>
/*immArgPositions=*/[2],
/*immArgAttrNames=*/["tile_id"]>;
-// Loads
+// Loads (from memory to ZA tile slice)
class ArmSME_IntrLoadOp<string mnemonic>
: ArmSME_IntrLoadStoreOp<mnemonic>,
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
@@ -134,7 +134,7 @@ def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
-// Stores
+// Stores (ZA tile slice to memory)
class ArmSME_IntrStoreOp<string mnemonic>
: ArmSME_IntrLoadStoreOp<mnemonic>,
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 7a9e6b4215754..ece6e6d2e7c12 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -76,13 +76,23 @@ func.func @use_too_many_tiles() {
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
//
// AFTER-LLVM-LOWERING-NOT: scf.for
-// Note: 17 is the mask for the 32-bit tile 0.
+
+/// 1. Create/allocate %0
+/// Note: 17 is the mask for the 32-bit tile 0.
+
// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}>
//
// AFTER-LLVM-LOWERING-NOT: scf.for
-// Note: 34 is the mask for the 32-bit tile 1.
+
+/// 2. Create/allocate %1
+/// Note: 34 is the mask for the 32-bit tile 1.
+
// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}>
-//
+
+/// 3. Spill %0 (the 32-bit tile 0), so that %2 can be allocated (16-bit
+/// tile 0). Note that this is spilling vector<[8]x[8]xi16> rather than
+/// vector<[4]x[4]xi32>.
+
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
@@ -92,8 +102,14 @@ func.func @use_too_many_tiles() {
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }
-// Note: 85 is the mask for the 16-bit tile 0.
+
+/// 4. Create/allocate %2
+/// Note: 85 is the mask for the 16-bit tile 0.
+
// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}>
+
+/// 5. Re-load %0
+
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
@@ -116,7 +132,7 @@ func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf3
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%mask = vector.constant_mask [4] : vector<[4]xi1>
%loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
- return %loadSlice : vector<[4]x[4]xf32>
+ "test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
}
// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
// AFTER-TILE-ALLOC: arm_sme.get_tile
@@ -133,22 +149,38 @@ func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf3
// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
//
+
+/// 1. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
+/// tile (vector<[4]x[4]xf32>)
+
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+/// Read ZA tile slice -> vector
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+/// Load vector from memory -> ZA tile
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+/// Store ZA tile slice in memory
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }
+
+/// 2. Load into %tile
// AFTER-LLVM-LOWERING: "arm_sme.intr.ld1w.horiz"{{.*}} <{tile_id = 0 : i32}>
+
+/// 3. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
+/// tile (vector<[4]x[4]xf32>)
+
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+/// Read ZA tile slice -> vector
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+/// Load vector from memory -> ZA tile
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+/// Store ZA tile slice in memory
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }
|
Thanks for the comments @MacDue. I've reworded my additions, please take another look. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more nit, otherwise, LG
vectors (e.g. vector<[4]x[4]xf32>) - that's not supported and not
needed for that test