Skip to content

Commit ba2b21a

Browse files
authored
[mlir][ArmSME] Audit ArmSME load/store ops (#139573)
This patch updates the following ArmSME ops to require that input and output element types match:

* `arm_sme.tile_load`, `arm_sme.tile_store`, `arm_sme.load_tile_slice`, `arm_sme.store_tile_slice`.

In addition, it ensures that the base memref operand for `tile_load` and `tile_store` is always rank-2, aligning with the semantics of Arm SME tiles (always rank-2).

This change is effectively a follow-up to #135151:

* "[mlir][vector] Tighten the semantics of vector.{load|store}"

The patch also updates `createLoadStoreForOverTileSlices` in ArmSMEToSCF.cpp to fail when processing invalid tile stores like the following:

```mlir
arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32>
```

This particular change fixes #118769. As noted in the TODO, we should further extend op verification logic — I plan to address that in a follow-up patch.
1 parent ddf1249 commit ba2b21a

File tree

4 files changed

+94
-20
lines changed

4 files changed

+94
-20
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -317,6 +317,7 @@ def CopyTileOp : ArmSME_Op<"copy_tile", [
317317
def TileLoadOp : ArmSME_Op<"tile_load", [
318318
ArmSMETileOpInterface,
319319
AttrSizedOperandSegments,
320+
AllElementTypesMatch<["result", "base"]>,
320321
OptionalTypesMatchWith<
321322
"padding type matches element type of result",
322323
"result", "padding",
@@ -369,7 +370,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
369370
```
370371
}];
371372
let arguments = (ins
372-
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
373+
Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
373374
Variadic<Index>:$indices,
374375
Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
375376
ArmSME_TileSliceLayoutAttr:$layout
@@ -407,6 +408,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
407408
def TileStoreOp : ArmSME_Op<"tile_store", [
408409
ArmSMETileOpInterface,
409410
AttrSizedOperandSegments,
411+
AllElementTypesMatch<["valueToStore", "base"]>,
410412
HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
411413
]> {
412414
let summary = "Tile store operation";
@@ -443,7 +445,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
443445
```
444446
}];
445447
let arguments = (ins SMETile:$valueToStore,
446-
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
448+
Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
447449
Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
448450
ArmSME_TileSliceLayoutAttr:$layout
449451
);
@@ -473,6 +475,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
473475

474476
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
475477
ArmSMETileOpInterface,
478+
AllElementTypesMatch<["tile", "base"]>,
476479
AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
477480
]> {
478481
let summary = "Tile slice load and update operation";
@@ -535,6 +538,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
535538

536539
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
537540
ArmSMETileOpInterface,
541+
AllElementTypesMatch<["tile", "base"]>,
538542
TileSliceMaskConstraint<"tile", "mask">
539543
]> {
540544
let summary = "Tile slice store operation";

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
3333
Value tileSliceIndex,
3434
Value tileSliceNumElts, Location loc,
3535
PatternRewriter &rewriter) {
36-
assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
36+
assert(rank == 2 && "memref has unexpected rank!");
3737
SmallVector<Value, 2> outIndices;
3838

3939
auto tileSliceOffset = tileSliceIndex;
40-
if (rank == 1)
41-
tileSliceOffset =
42-
rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
4340

4441
auto baseIndexPlusTileSliceOffset =
4542
rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
4643
outIndices.push_back(baseIndexPlusTileSliceOffset);
47-
48-
if (rank == 2)
49-
outIndices.push_back(indices[1]);
44+
outIndices.push_back(indices[1]);
5045

5146
return outIndices;
5247
}
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
6055
makeLoopBody) {
6156
PatternRewriter::InsertionGuard guard(rewriter);
6257

58+
// TODO: This case should be captured and rejected by a verifier.
59+
if (memrefIndices.size() != 2)
60+
return rewriter.notifyMatchFailure(loc, "invalid number of indices");
61+
6362
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
6463
loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
6564
auto vscale =

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 76 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ func.func @arm_sme_get_tile__bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
5050

5151
// -----
5252

53-
func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
53+
func.func @arm_sme_insert_tile_slice_i8__bad_vector_length(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
5454
%c0 = arith.constant 0 : index
5555
// expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
5656
%0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xi8> into vector<[16]x[16]xi8>
@@ -59,23 +59,40 @@ func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8
5959

6060
// -----
6161

62-
func.func @arm_sme_insert_tile_slice_f32__bad_vector_type(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
62+
func.func @arm_sme_insert_tile_slice_f32__bad_vector_length(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
6363
%c0 = arith.constant 0 : index
6464
// expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
6565
%0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xf32> into vector<[4]x[4]xf32>
6666
return %0 : vector<[4]x[4]xf32>
6767
}
6868

69+
// -----
70+
71+
func.func @arm_sme_insert_tile_slice__bad_element_type(%vector : vector<[4]xf64>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
72+
%c0 = arith.constant 0 : index
73+
// expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
74+
%0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[4]xf64> into vector<[4]x[4]xf32>
75+
return %0 : vector<[4]x[4]xf32>
76+
}
77+
6978
//===----------------------------------------------------------------------===//
7079
// arm_sme.extract_tile_slice
7180
//===----------------------------------------------------------------------===//
7281

7382
// -----
7483

75-
func.func @arm_sme_extract_tile_slice__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
84+
func.func @arm_sme_extract_tile_slice__bad_result_length(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf32> {
85+
// expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
86+
%0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf32> from vector<[4]x[4]xf32>
87+
return %0 : vector<[2]xf32>
88+
}
89+
90+
// -----
91+
92+
func.func @arm_sme_extract_tile_slice__bad_result_element_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf64> {
7693
// expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
77-
%0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
78-
return %0 : vector<[2]xf64>
94+
%0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[4]xf64> from vector<[4]x[4]xf32>
95+
return %0 : vector<[4]xf64>
7996
}
8097

8198
//===----------------------------------------------------------------------===//
@@ -111,6 +128,24 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
111128
return
112129
}
113130

131+
// -----
132+
133+
func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
134+
%c0 = arith.constant 0 : index
135+
// expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
136+
%tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
137+
return
138+
}
139+
140+
// -----
141+
142+
func.func @arm_sme_tile_load__bad_element_type(%src : memref<?x?xi32>) {
143+
%c0 = arith.constant 0 : index
144+
// expected-error@+1 {{failed to verify that all of {result, base} have same element type}}
145+
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
146+
return
147+
}
148+
114149
//===----------------------------------------------------------------------===//
115150
// arm_sme.load_tile_slice
116151
//===----------------------------------------------------------------------===//
@@ -124,6 +159,15 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
124159
return
125160
}
126161

162+
// -----
163+
164+
func.func @arm_sme_load_tile_slice__bad_element_type(%src : memref<?x?xi32>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
165+
%c0 = arith.constant 0 : index
166+
// expected-error@+1 {{op failed to verify that all of {tile, base} have same element type}}
167+
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
168+
return
169+
}
170+
127171
//===----------------------------------------------------------------------===//
128172
// arm_sme.tile_store
129173
//===----------------------------------------------------------------------===//
@@ -138,6 +182,24 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
138182
return
139183
}
140184

185+
// -----
186+
187+
func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
188+
%c0 = arith.constant 0 : index
189+
// expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
190+
arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
191+
return
192+
}
193+
194+
// -----
195+
196+
func.func @arm_sme_tile_store__bad_element_type(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi32>) {
197+
%c0 = arith.constant 0 : index
198+
// expected-error@+1 {{op failed to verify that all of {valueToStore, base} have same element type}}
199+
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
200+
return
201+
}
202+
141203
//===----------------------------------------------------------------------===//
142204
// arm_sme.store_tile_slice
143205
//===----------------------------------------------------------------------===//
@@ -152,6 +214,15 @@ func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>,
152214
return
153215
}
154216

217+
// -----
218+
219+
func.func @arm_sme_store_tile_slice__bad_element_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi32>) -> () {
220+
%c0 = arith.constant 0 : index
221+
// expected-error@+1 {{op failed to verify that all of {tile, base} have same element type}}
222+
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
223+
return
224+
}
225+
155226
//===----------------------------------------------------------------------===//
156227
// arm_sme.outerproduct
157228
//===----------------------------------------------------------------------===//

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ func.func @entry() {
1717
%za_s_size = arith.muli %svl_s, %svl_s : index
1818

1919
// Allocate memory.
20-
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
20+
%mem1 = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
2121

2222
// Fill each "row" of "mem1" with row number.
2323
//
@@ -29,15 +29,15 @@ func.func @entry() {
2929
// 3, 3, 3, 3
3030
//
3131
%init_0 = arith.constant 0 : i32
32-
scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
32+
scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
3333
%splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
34-
vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
34+
vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
3535
%val_next = arith.addi %val, %c1_i32 : i32
3636
scf.yield %val_next : i32
3737
}
3838

3939
// Load tile from "mem1" vertically.
40-
%0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
40+
%0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
4141

4242
// 1. ORIGINAL HORIZONTAL LAYOUT
4343
// Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
@@ -50,8 +50,8 @@ func.func @entry() {
5050
// CHECK-NEXT: ( 3, 3, 3, 3
5151
// CHECK: TILE END
5252
vector.print str "TILE BEGIN\n"
53-
scf.for %i = %c0 to %za_s_size step %svl_s {
54-
%tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
53+
scf.for %i = %c0 to %svl_s step %c1 {
54+
%tileslice = vector.load %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
5555
vector.print %tileslice : vector<[4]xi32>
5656
}
5757
vector.print str "TILE END\n"

0 commit comments

Comments
 (0)