Skip to content

Commit 009cf93

Browse files
committed
make tile slice layout default
1 parent ea65d6d commit 009cf93

File tree

9 files changed

+306
-265
lines changed

9 files changed

+306
-265
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -259,35 +259,34 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
259259
let summary = "Tile load operation";
260260
let description = [{
261261
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
262-
with the shape defined by the 2D scalable vector type of the result tile. A
263-
tile slice layout attribute specifies whether the slices of the tile being
264-
loaded are horizontal or vertical. The slice of memory must be contiguous.
265-
The memref must be either rank 1 or rank 2 with dynamic dimensions, since
266-
the operation is scalable, and the element type must be a scalar that
267-
matches the element type of the result.
268-
269-
The default tile slice layout when lowering from higher-level dialects is
270-
horizontal.
271-
272-
Example 1: Load an 8-bit element ZA tile with horizontal layout from memory (ZA0.B).
262+
with the shape defined by the 2D scalable vector type of the result tile.
263+
An optional tile slice layout attribute specifies whether the slices of the
264+
tile being loaded are horizontal (default) or vertical. The slice of memory
265+
must be contiguous. The memref must be either rank 1 or rank 2 with
266+
dynamic dimensions, since the operation is scalable, and the element type
267+
must be a scalar that matches the element type of the result.
268+
269+
Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
273270
```mlir
274-
%tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
271+
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
275272
```
276273

277274
Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
278275
```mlir
279-
%tile = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
276+
%tile = arm_sme.tile_load %base[%c0, %c0], <ver> : memref<?x?xf32>, vector<[4]x[4]xf32>
280277
```
281278

282-
Example 3: Load a 128-bit element ZA tile with horizontal layout from memory.
279+
Example 3: Load a 128-bit element ZA tile from memory with the horizontal layout (the default) spelled out explicitly.
283280
```mlir
284-
%tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
281+
%tile = arm_sme.tile_load %base[%c0, %c0], <hor> : memref<?x?xi128>, vector<[1]x[1]xi128>
285282
```
286283
}];
287284
let arguments = (ins
288-
ArmSME_TileSliceLayoutAttr:$layout,
289-
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
290-
Variadic<Index>:$indices);
285+
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
286+
Variadic<Index>:$indices,
287+
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
288+
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
289+
);
291290
let results = (outs SMETile:$result);
292291

293292
let extraClassDeclaration = [{
@@ -300,43 +299,42 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
300299
}];
301300

302301
let assemblyFormat =
303-
"$layout `,` $base `[` $indices `]` attr-dict "
304-
"`:` type($base) `,` type($result)";
302+
"$base `[` $indices `]` (`,` $layout^)? attr-dict "
303+
"`:` type($base) `,` type($result)";
305304
}
306305

307306
def TileStoreOp : ArmSME_Op<"tile_store"> {
308307
let summary = "Tile store operation";
309308
let description = [{
310309
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
311310
with the shape defined by the 2D scalable vector type of the tile being
312-
stored. A tile slice layout attribute specifies whether the slices of the
313-
tile being stored are horizontal or vertical. The slice of memory must be
314-
contiguous. The memref must be either rank 1 or rank 2 with dynamic
315-
dimensions, since the operation is scalable, and the element type must be a
316-
scalar that matches the element type of the result.
317-
318-
The default tile slice layout when lowering from higher-level dialects is
319-
horizontal.
311+
stored. An optional tile slice layout attribute specifies whether the
312+
slices of the tile being stored are horizontal (default) or vertical. The
313+
slice of memory must be contiguous. The memref must be either rank 1 or
314+
rank 2 with dynamic dimensions, since the operation is scalable, and the
315+
element type must be a scalar that matches the element type of the tile being stored.
320316

321-
Example 1: Store an 8-bit element ZA tile with horizontal layout to memory (ZA0.B).
317+
Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
322318
```mlir
323-
arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
319+
arm_sme.tile_store %tile, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
324320
```
325321

326322
Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
327323
```mlir
328-
arm_sme.tile_store %tile, <ver>, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
324+
arm_sme.tile_store %tile, %base[%c0, %c0], <ver> : memref<?x?xf32>, vector<[4]x[4]xf32>
329325
```
330326

331-
Example 3: Store a 128-bit element ZA tile with horizontal layout to memory.
327+
Example 3: Store a 128-bit element ZA tile to memory with the horizontal layout (the default) spelled out explicitly.
332328
```mlir
333-
arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
329+
arm_sme.tile_store %tile, %base[%c0, %c0], <hor> : memref<?x?xi128>, vector<[1]x[1]xi128>
334330
```
335331
}];
336332
let arguments = (ins SMETile:$valueToStore,
337-
ArmSME_TileSliceLayoutAttr:$layout,
338-
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
339-
Variadic<Index>:$indices);
333+
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
334+
Variadic<Index>:$indices,
335+
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
336+
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
337+
);
340338
let extraClassDeclaration = [{
341339
MemRefType getMemRefType() {
342340
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -347,7 +345,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
347345
}];
348346

349347
let assemblyFormat =
350-
"$valueToStore `,` $layout `,` $base `[` $indices `]` attr-dict "
348+
"$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
351349
"`:` type($base) `,` type($valueToStore)";
352350
}
353351

@@ -359,23 +357,23 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
359357
Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
360358
slice is defined by the dimension of the 2D scalable vector type pointed by
361359
the index. A tile slice index describes where in the input tile the tile
362-
slice is loaded to. A tile slice layout attribute specifies whether the
363-
tile slice being loaded at the given index is horizontal or vertical. The
364-
updated tile is returned as the result.
360+
slice is loaded to. An optional tile slice layout attribute specifies
361+
whether the tile slice being loaded at the given index is horizontal
362+
(default) or vertical. The updated tile is returned as the result.
365363

366364
The slice of memory read is defined by a base and indices and must be
367365
contiguous. The memref must be either rank 1 or rank 2, have dynamic
368366
dimensions since the operation is scalable, and the element type must be a
369367
scalar that matches the element type of the result.
370368

371-
Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally at given index.
369+
Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
372370
```mlir
373-
%tile_update = arm_sme.load_tile_slice <hor>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
371+
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
374372
```
375373

376374
Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
377375
```mlir
378-
%tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
376+
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <ver> : memref<?x?xf32>, vector<[4]x[4]xf32>
379377
```
380378

381379
Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
@@ -384,9 +382,11 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
384382
```
385383
}];
386384
let arguments = (ins
387-
ArmSME_TileSliceLayoutAttr:$layout,
388-
Arg<AnyMemRef, "the reference to load from">:$base,
389-
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
385+
Arg<AnyMemRef, "the reference to load from">:$base,
386+
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
387+
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
388+
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
389+
);
390390
let results = (outs SMETile:$result);
391391

392392
let extraClassDeclaration = [{
@@ -399,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
399399
}];
400400

401401
let assemblyFormat = [{
402-
$layout `,` $base `[` $indices `]` `,` $tile `,` $tile_slice_index
402+
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
403403
attr-dict `:` type($base) `,` type($result)
404404
}];
405405
}
@@ -410,33 +410,36 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
410410
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
411411
slice is defined by the dimension of the 2D scalable vector type pointed by
412412
the index. A tile slice index describes where in the input tile the tile
413-
slice is stored from. A tile slice layout attribute specifies whether the
414-
tile slice being stored from the given index is horizontal or vertical.
413+
slice is stored from. An optional tile slice layout attribute specifies
414+
whether the tile slice being stored from the given index is horizontal
415+
(default) or vertical.
415416

416417
The slice of memory written is defined by a base and indices and must be
417418
contiguous. The memref must be either rank 1 or rank 2, have dynamic
418419
dimensions since the operation is scalable, and the element type must be a
419420
scalar that matches the element type of the input tile.
420421

421-
Example 1: Store vector<[16]xi8> horizontal tile slice from tile at given index to memory.
422+
Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
422423
```mlir
423-
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
424+
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
424425
```
425426

426427
Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
427428
```mlir
428-
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
429+
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <ver> : memref<?x?xf32>, vector<[4]x[4]xf32>
429430
```
430431

431432
Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
432433
```mlir
433-
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
434+
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <ver> : memref<?x?xi128>, vector<[1]x[1]xi128>
434435
```
435436
}];
436437
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
437-
ArmSME_TileSliceLayoutAttr:$layout,
438-
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
439-
Variadic<Index>:$indices);
438+
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
439+
Variadic<Index>:$indices,
440+
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
441+
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
442+
);
440443
let extraClassDeclaration = [{
441444
MemRefType getMemRefType() {
442445
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -447,7 +450,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
447450
}];
448451

449452
let assemblyFormat = [{
450-
$tile `,` $tile_slice_index `,` $layout `,` $base `[` $indices `]`
453+
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
451454
attr-dict `:` type($base) `,` type($tile)
452455
}];
453456
}

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
117117
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
118118
numTileSlices, memrefIndices, loc, rewriter);
119119
rewriter.create<arm_sme::LoadTileSliceOp>(
120-
loc, tileType, tileLoadOp.getLayout(), tileLoadOp.getBase(), tile,
121-
memrefIndices, tileSliceIndex);
120+
loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
121+
tileSliceIndex, tileLoadOp.getLayout());
122122

123123
rewriter.setInsertionPointAfter(forOp);
124124

@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
184184
numTileSlices, memrefIndices, loc, rewriter);
185185
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
186186
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
187-
tileStoreOp.getLayout(), tileStoreOp.getBase(), memrefIndices);
187+
tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
188188

189189
return success();
190190
}

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ struct TransferWriteToArmSMELowering
8181
return failure();
8282

8383
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
84-
writeOp, writeOp.getVector(), arm_sme::TileSliceLayout::Horizontal,
85-
writeOp.getSource(), writeOp.getIndices());
84+
writeOp, writeOp.getVector(), writeOp.getSource(),
85+
writeOp.getIndices());
8686
return success();
8787
}
8888
};
@@ -97,8 +97,7 @@ struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
9797
return failure();
9898

9999
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
100-
load, load.getVectorType(), arm_sme::TileSliceLayout::Horizontal,
101-
load.getBase(), load.getIndices());
100+
load, load.getVectorType(), load.getBase(), load.getIndices());
102101

103102
return success();
104103
}
@@ -114,8 +113,7 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
114113
return failure();
115114

116115
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
117-
store, store.getValueToStore(), arm_sme::TileSliceLayout::Horizontal,
118-
store.getBase(), store.getIndices());
116+
store, store.getValueToStore(), store.getBase(), store.getIndices());
119117

120118
return success();
121119
}

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,20 @@
1111
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
1212
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
1313
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
14-
// CHECK-NEXT: arm_sme.load_tile_slice <hor>, %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
14+
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
1515
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
1616
%c0 = arith.constant 0 : index
17-
%tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
17+
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
1818
return
1919
}
2020

2121
// -----
2222

2323
// CHECK-LABEL: @arm_sme_tile_load_ver
24-
// CHECK: arm_sme.load_tile_slice <ver>
24+
// CHECK: arm_sme.load_tile_slice {{.*}} <ver>
2525
func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
2626
%c0 = arith.constant 0 : index
27-
%tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
27+
%tile = arm_sme.tile_load %src[%c0, %c0], <ver> : memref<?x?xi32>, vector<[4]x[4]xi32>
2828
return
2929
}
3030

@@ -40,10 +40,10 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
4040
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
4141
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
4242
// CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
43-
// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], <hor>, %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
43+
// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
4444
func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
4545
%c0 = arith.constant 0 : index
46-
arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
46+
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
4747
return
4848
}
4949

@@ -53,6 +53,6 @@ func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
5353
// CHECK: arm_sme.store_tile_slice {{.*}} <ver>
5454
func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
5555
%c0 = arith.constant 0 : index
56-
arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
56+
arm_sme.tile_store %tile, %dest[%c0, %c0], <ver> : memref<?x?xi32>, vector<[4]x[4]xi32>
5757
return
5858
}

mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
1616
%c0 = arith.constant 0 : index
1717
%tile = arm_sme.zero : vector<[16]x[16]xi8>
18-
arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
18+
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
1919
return
2020
}
2121

@@ -32,7 +32,7 @@ func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
3232
// CHECK: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
3333
func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
3434
%c0 = arith.constant 0 : index
35-
%tile = arm_sme.tile_load <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
35+
%tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
3636
return %tile : vector<[16]x[16]xi8>
3737
}
3838

@@ -46,6 +46,6 @@ func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
4646
// CHECK: "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
4747
func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
4848
%c0 = arith.constant 0 : index
49-
arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
49+
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
5050
return
5151
}

0 commit comments

Comments
 (0)