Skip to content

Commit b1cbf4a

Browse files
authored
[mlir][ArmSME] Add comments in tile-spills-and-fills.mlir (#91450)
* Adds comments in tile-spills-and-fills.mlir. * Adds comments in ArmSMEIntrinsicOps.td. * Updates the test in tile-spills-and-fills.mlir so that it does not return 2D scalable vectors (e.g. vector<[4]x[4]xf32>) - that is not supported and not needed for that test.
1 parent fe0b798 commit b1cbf4a

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ArmSME_IntrLoadStoreOp<string mnemonic>
115115
/*immArgPositions=*/[2],
116116
/*immArgAttrNames=*/["tile_id"]>;
117117

118-
// Loads
118+
// Loads (from memory to ZA tile slice)
119119
class ArmSME_IntrLoadOp<string mnemonic>
120120
: ArmSME_IntrLoadStoreOp<mnemonic>,
121121
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
@@ -134,7 +134,7 @@ def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
134134
def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
135135
def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
136136

137-
// Stores
137+
// Stores (from ZA tile slice to memory)
138138
class ArmSME_IntrStoreOp<string mnemonic>
139139
: ArmSME_IntrLoadStoreOp<mnemonic>,
140140
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,

mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,32 @@ func.func @use_too_many_tiles() {
7272
// AFTER-LLVM-LOWERING-DAG: %[[C8:.*]] = arith.constant 8 : index
7373
// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
7474
// AFTER-LLVM-LOWERING-DAG: %[[SVL_H:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
75+
76+
/// 0. Create an in-memory tile
77+
/// Note: 16 is an in-memory tile ID, i.e. a tile ID >= 16
78+
7579
// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]])
7680
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
7781
//
7882
// AFTER-LLVM-LOWERING-NOT: scf.for
79-
// Note: 17 is the mask for the 32-bit tile 0.
83+
84+
/// 1. The following instruction corresponds to %0 after tile allocation
85+
/// Note: 17 is the mask for the 32-bit tile 0.
86+
8087
// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}>
8188
//
8289
// AFTER-LLVM-LOWERING-NOT: scf.for
83-
// Note: 34 is the mask for the 32-bit tile 1.
90+
91+
/// 2. The following instruction corresponds to %1 after tile allocation
92+
/// Note: 34 is the mask for the 32-bit tile 1.
93+
8494
// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}>
85-
//
95+
96+
/// 3. swap(<in-memory-tile>, tile 0).
97+
/// This can be interpreted as spilling %0 (the 32-bit tile 0), so that
98+
/// %2 can be allocated a tile (16-bit tile 0). Note that this is
99+
/// swapping vector<[8]x[8]xi16> rather than vector<[4]x[4]xi32>.
100+
86101
// AFTER-LLVM-LOWERING: scf.for
87102
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
88103
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
@@ -92,8 +107,15 @@ func.func @use_too_many_tiles() {
92107
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
93108
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
94109
// AFTER-LLVM-LOWERING-NEXT: }
95-
// Note: 85 is the mask for the 16-bit tile 0.
110+
111+
/// 4. The following instruction corresponds to %3 after tile allocation
112+
/// Note: 85 is the mask for the 16-bit tile 0.
113+
96114
// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}>
115+
116+
/// 5. swap(<in-memory-tile>, tile 0).
117+
/// This can be interpreted as restoring %0.
118+
97119
// AFTER-LLVM-LOWERING: scf.for
98120
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
99121
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
@@ -116,7 +138,7 @@ func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf3
116138
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
117139
%mask = vector.constant_mask [4] : vector<[4]xi1>
118140
%loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
119-
return %loadSlice : vector<[4]x[4]xf32>
141+
"test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
120142
}
121143
// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
122144
// AFTER-TILE-ALLOC: arm_sme.get_tile
@@ -133,22 +155,38 @@ func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf3
133155
// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
134156
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
135157
//
158+
159+
/// 1. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
160+
/// tile (vector<[4]x[4]xf32>)
161+
136162
// AFTER-LLVM-LOWERING: scf.for
137163
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
138164
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
139165
// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
140166
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
167+
/// Read ZA tile slice -> vector
141168
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
169+
/// Load vector from memory -> ZA tile
142170
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
171+
/// Store ZA tile slice in memory
143172
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
144173
// AFTER-LLVM-LOWERING-NEXT: }
174+
175+
/// 2. Load into %tile
145176
// AFTER-LLVM-LOWERING: "arm_sme.intr.ld1w.horiz"{{.*}} <{tile_id = 0 : i32}>
177+
178+
/// 3. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
179+
/// tile (vector<[4]x[4]xf32>)
180+
146181
// AFTER-LLVM-LOWERING: scf.for
147182
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
148183
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
149184
// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
150185
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
186+
/// Read ZA tile slice -> vector
151187
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
188+
/// Load vector from memory -> ZA tile
152189
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
190+
/// Store ZA tile slice in memory
153191
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
154192
// AFTER-LLVM-LOWERING-NEXT: }

0 commit comments

Comments
 (0)