@@ -60,6 +60,12 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
60
60
"::llvm::cast<VectorType>($_self).getElementType())"
61
61
".getWidth())">;
62
62
63
+ class HasMatchingMaskTypeConstraint<string vector, string mask> :
64
+ OptionalTypesMatchWith<
65
+ mask # " has i1 element type and same shape as " # vector,
66
+ vector, mask,
67
+ "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
68
+
63
69
//===----------------------------------------------------------------------===//
64
70
// ArmSME attr definitions
65
71
//===----------------------------------------------------------------------===//
@@ -259,14 +265,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
259
265
"result", "padding",
260
266
"::llvm::cast<VectorType>($_self).getElementType()"
261
267
>,
262
- OptionalTypesMatchWith<
263
- "mask has i1 element type and same shape as result",
264
- "result", "mask",
265
- "VectorType("
266
- "VectorType::Builder("
267
- "::llvm::cast<mlir::VectorType>($_self)"
268
- ").setElementType(IntegerType::get($_self.getContext(), 1)))"
269
- >,
268
+ HasMatchingMaskTypeConstraint<"result", "mask">,
270
269
PredOpTrait<
271
270
"both `padding` and `mask` should be provided or neither",
272
271
CPred<"bool(getPadding()) == bool(getMask())">
@@ -345,7 +344,10 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
345
344
"attr-dict `:` type($base) `,` type($result)";
346
345
}
347
346
348
- def TileStoreOp : ArmSME_Op<"tile_store"> {
347
+ def TileStoreOp : ArmSME_Op<"tile_store", [
348
+ AttrSizedOperandSegments,
349
+ HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
350
+ ]> {
349
351
let summary = "Tile store operation";
350
352
let description = [{
351
353
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
@@ -356,6 +358,11 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
356
358
rank 2 with dynamic dimensions, since the operation is scalable, and the
357
359
element type must be a scalar that matches the element type of the result.
358
360
361
+ An optional SSA value `mask` may be specified to mask out elements written
362
+ to the MemRef. The `mask` type is an `i1` vector of the same shape as the
363
+ vector type that matches how elements are written into the MemRef. Elements
364
+ whose corresponding mask element is `0` are masked out.
365
+
359
366
Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
360
367
```mlir
361
368
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -370,10 +377,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
370
377
```mlir
371
378
arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
372
379
```
380
+
381
+ Example 4: Masked store a int 32-bit element ZA tile with vertical layout to memory.
382
+ ```mlir
383
+ arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
384
+ ```
373
385
}];
374
386
let arguments = (ins SMETile:$valueToStore,
375
387
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
376
- Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
388
+ Variadic<Index>:$indices, Optional<AnyVector>:$mask,
389
+ ArmSME_TileSliceLayoutAttr:$layout
377
390
);
378
391
let extraClassDeclaration = [{
379
392
MemRefType getMemRefType() {
@@ -384,9 +397,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
384
397
}
385
398
}];
386
399
400
+ let builders = [
401
+ OpBuilder<(ins "Value":$valueToStore, "Value":$base,
402
+ "ValueRange":$indices), [{
403
+ build($_builder, $_state, valueToStore, base, indices, {});
404
+ }]>,
405
+ ];
406
+
387
407
let assemblyFormat =
388
- "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
389
- "`:` type($base) `,` type($valueToStore)";
408
+ "$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (` layout` `` $layout^)?"
409
+ "attr-dict `:` type($base) `,` type($valueToStore)";
390
410
}
391
411
392
412
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
0 commit comments