Skip to content

Commit 8f564e0

Browse files
authored
[mlir][ArmSME] Add mask operand to store_tile_slice (#70838)
1 parent 0d21436 commit 8f564e0

File tree

7 files changed

+156
-134
lines changed

7 files changed

+156
-134
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ class HasMatchingMaskTypeConstraint<string vector, string mask> :
6666
vector, mask,
6767
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
6868

69+
class TileSliceMaskConstraint<string tile, string mask> :
70+
TypesMatchWith<
71+
"`" # mask # "` has i1 element type and the shape is a slice of `" # tile # "`",
72+
tile, mask,
73+
"VectorType("
74+
"VectorType::Builder("
75+
"::llvm::cast<mlir::VectorType>($_self)"
76+
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1)))">;
77+
6978
//===----------------------------------------------------------------------===//
7079
// ArmSME attr definitions
7180
//===----------------------------------------------------------------------===//
@@ -408,15 +417,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
408417
}
409418

410419
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
411-
AllTypesMatch<["tile", "result"]>,
412-
TypesMatchWith<
413-
"mask has i1 element type and is a slice of the result",
414-
"result", "mask",
415-
"VectorType("
416-
"VectorType::Builder("
417-
"::llvm::cast<mlir::VectorType>($_self)"
418-
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
419-
")">,
420+
AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
420421
]> {
421422
let summary = "Tile slice load and update operation";
422423
let description = [{
@@ -432,9 +433,8 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
432433
dimensions since the operation is scalable, and the element type must be a
433434
scalar that matches the element type of the result.
434435

435-
An SSA value `mask` specifies to mask out elements read from the MemRef.
436-
The `mask` type is an `i1` vector with a shape that matches how elements
437-
are read from the MemRef.
436+
The provided `mask` is used to specify which elements of the tile slice
437+
will be loaded.
438438

439439
Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
440440
```mlir
@@ -474,7 +474,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
474474
}];
475475
}
476476

477-
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
477+
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
478+
TileSliceMaskConstraint<"tile", "mask">
479+
]> {
478480
let summary = "Tile slice store operation";
479481
let description = [{
480482
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
@@ -489,22 +491,26 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
489491
dimensions since the operation is scalable, and the element type must be a
490492
scalar that matches the element type of the input tile.
491493

494+
The provided `mask` is used to specify which elements of the tile slice
495+
will be stored.
496+
492497
Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
493498
```mlir
494-
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
499+
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : vector<[16]x[16]xi8>, vector<[16]xi1>, memref<?x?xi8>
495500
```
496501

497502
Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
498503
```mlir
499-
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
504+
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, vector<[4]xi1>, memref<?x?xf32>
500505
```
501506

502507
Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
503508
```mlir
504-
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
509+
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, vector<[1]xi1>, memref<?x?xi128>
505510
```
506511
}];
507-
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
512+
let arguments = (ins
513+
SMETile:$tile, Index:$tile_slice_index, SVEPredicate:$mask,
508514
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
509515
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
510516
);
@@ -518,8 +524,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
518524
}];
519525

520526
let assemblyFormat = [{
521-
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
522-
attr-dict `:` type($base) `,` type($tile)
527+
$tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)?
528+
attr-dict `:` type($base) `,` type($mask) `,` type($tile)
523529
}];
524530
}
525531

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,21 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
190190

191191
rewriter.setInsertionPointToStart(forOp.getBody());
192192

193+
// Create an 'all true' predicate for the tile slice.
194+
auto predicateType =
195+
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
196+
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
197+
loc, DenseElementsAttr::get(predicateType, true));
198+
193199
SmallVector<Value> memrefIndices;
194200
auto tileSliceIndex = forOp.getInductionVar();
195201
getMemrefIndices(tileStoreOp.getIndices(),
196202
tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
197203
numTileSlices, memrefIndices, loc, rewriter);
198204
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
199205
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
200-
tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
206+
allTruePredicate, tileStoreOp.getBase(), memrefIndices,
207+
tileStoreOp.getLayout());
201208

202209
return success();
203210
}

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,7 @@ struct StoreTileSliceToArmSMELowering
278278
auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
279279
loc, rewriter.getI32Type(), tileSlice);
280280

281-
// Create all active predicate mask.
282-
auto one = rewriter.create<arith::ConstantOp>(
283-
loc, rewriter.getI1Type(),
284-
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
285-
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
286-
/*scalableDims=*/{true});
287-
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
281+
auto maskOp = storeTileSliceOp.getMask();
288282

289283
Value tileI32 = castTileIDToI32(tile, loc, rewriter);
290284
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
@@ -295,23 +289,23 @@ struct StoreTileSliceToArmSMELowering
295289
llvm_unreachable("unexpected element type!");
296290
case 8:
297291
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
298-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
292+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
299293
break;
300294
case 16:
301295
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
302-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
296+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
303297
break;
304298
case 32:
305299
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
306-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
300+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
307301
break;
308302
case 64:
309303
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
310-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
304+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
311305
break;
312306
case 128:
313307
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
314-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
308+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
315309
break;
316310
}
317311
} else {
@@ -320,23 +314,23 @@ struct StoreTileSliceToArmSMELowering
320314
llvm_unreachable("unexpected element type!");
321315
case 8:
322316
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
323-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
317+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
324318
break;
325319
case 16:
326320
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
327-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
321+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
328322
break;
329323
case 32:
330324
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
331-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
325+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
332326
break;
333327
case 64:
334328
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
335-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
329+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
336330
break;
337331
case 128:
338332
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
339-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
333+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
340334
break;
341335
}
342336
}

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
4848
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
4949
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
5050
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
51+
// CHECK: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
5152
// CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
52-
// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
53+
// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
5354
func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
5455
%c0 = arith.constant 0 : index
5556
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

0 commit comments

Comments
 (0)