@@ -66,6 +66,15 @@ class HasMatchingMaskTypeConstraint<string vector, string mask> :
66
66
vector, mask,
67
67
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
68
68
69
+ class TileSliceMaskConstraint<string tile, string mask> :
70
+ TypesMatchWith<
71
+ "`" # mask # "` has i1 element type and the shape is a slice of `" # tile # "`",
72
+ tile, mask,
73
+ "VectorType("
74
+ "VectorType::Builder("
75
+ "::llvm::cast<mlir::VectorType>($_self)"
76
+ ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1)))">;
77
+
69
78
//===----------------------------------------------------------------------===//
70
79
// ArmSME attr definitions
71
80
//===----------------------------------------------------------------------===//
@@ -408,15 +417,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
408
417
}
409
418
410
419
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
411
- AllTypesMatch<["tile", "result"]>,
412
- TypesMatchWith<
413
- "mask has i1 element type and is a slice of the result",
414
- "result", "mask",
415
- "VectorType("
416
- "VectorType::Builder("
417
- "::llvm::cast<mlir::VectorType>($_self)"
418
- ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
419
- ")">,
420
+ AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
420
421
]> {
421
422
let summary = "Tile slice load and update operation";
422
423
let description = [{
@@ -432,9 +433,8 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
432
433
dimensions since the operation is scalable, and the element type must be a
433
434
scalar that matches the element type of the result.
434
435
435
- An SSA value `mask` specifies to mask out elements read from the MemRef.
436
- The `mask` type is an `i1` vector with a shape that matches how elements
437
- are read from the MemRef.
436
+ The provided `mask` is used to specify which elements of the tile slice
437
+ will be loaded.
438
438
439
439
Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
440
440
```mlir
@@ -474,7 +474,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
474
474
}];
475
475
}
476
476
477
- def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
477
+ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
478
+ TileSliceMaskConstraint<"tile", "mask">
479
+ ]> {
478
480
let summary = "Tile slice store operation";
479
481
let description = [{
480
482
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
@@ -489,22 +491,26 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
489
491
dimensions since the operation is scalable, and the element type must be a
490
492
scalar that matches the element type of the input tile.
491
493
494
+ The provided `mask` is used to specify which elements of the tile slice
495
+ will be stored.
496
+
492
497
Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
493
498
```mlir
494
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
499
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, % base[%c0] : vector<[16]x[16]xi8>, vector<[16]xi1 >, memref<?x?xi8>
495
500
```
496
501
497
502
Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
498
503
```mlir
499
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
504
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, % base[%c0] layout<vertical> : vector<[4]x[4]xf32>, vector<[4]xi1 >, memref<?x?xf32>
500
505
```
501
506
502
507
Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
503
508
```mlir
504
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
509
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, % base[%c0] layout<vertical> : vector<[1]x[1]xi128>, vector<[1]xi1 >, memref<?x?xi128>
505
510
```
506
511
}];
507
- let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
512
+ let arguments = (ins
513
+ SMETile:$tile, Index:$tile_slice_index, SVEPredicate:$mask,
508
514
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
509
515
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
510
516
);
@@ -518,8 +524,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
518
524
}];
519
525
520
526
let assemblyFormat = [{
521
- $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
522
- attr-dict `:` type($base) `,` type($tile)
527
+ $tile `,` $tile_slice_index `,` $mask `,` $ base `[` $indices `]` (`layout` `` $layout^)?
528
+ attr-dict `:` type($base) `,` type($mask) `,` type($ tile)
523
529
}];
524
530
}
525
531
0 commit comments