Skip to content

Commit 8ea260a

Browse files
authored
[mlir][ArmSME] Add mask operand to load_tile_slice (#70655)
1 parent 4b29e8c commit 8ea260a

File tree

7 files changed

+174
-127
lines changed

7 files changed

+174
-127
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,15 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
390390
}
391391

392392
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
393-
AllTypesMatch<["tile", "result"]>
393+
AllTypesMatch<["tile", "result"]>,
394+
TypesMatchWith<
395+
"mask has i1 element type and is a slice of the result",
396+
"result", "mask",
397+
"VectorType("
398+
"VectorType::Builder("
399+
"::llvm::cast<mlir::VectorType>($_self)"
400+
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
401+
")">,
394402
]> {
395403
let summary = "Tile slice load and update operation";
396404
let description = [{
@@ -406,23 +414,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
406414
dimensions since the operation is scalable, and the element type must be a
407415
scalar that matches the element type of the result.
408416

417+
An SSA value `mask` specifies to mask out elements read from the MemRef.
418+
The `mask` type is an `i1` vector with a shape that matches how elements
419+
are read from the MemRef.
420+
409421
Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
410422
```mlir
411-
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
423+
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
412424
```
413425

414426
Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
415427
```mlir
416-
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
428+
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
417429
```
418430

419431
Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
420432
```mlir
421-
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
433+
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
422434
```
423435
}];
424436
let arguments = (ins
425-
Arg<AnyMemRef, "the reference to load from">:$base,
437+
Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
426438
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
427439
ArmSME_TileSliceLayoutAttr:$layout
428440
);
@@ -438,8 +450,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
438450
}];
439451

440452
let assemblyFormat = [{
441-
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
442-
attr-dict `:` type($base) `,` type($result)
453+
$base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index
454+
(`layout` `` $layout^)? attr-dict `:` type($base) `,` type($mask) `,`
455+
type($result)
443456
}];
444457
}
445458

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
6060
///
6161
/// AFTER:
6262
/// ```mlir
63+
/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
6364
/// %tile_id = arm_sme.get_tile_id : i32
6465
/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
6566
/// %vscale = vector.vscale
@@ -69,14 +70,20 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
6970
/// %svl_s = arith.muli %min_svl_s, %vscale : index
7071
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
7172
/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
72-
/// %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
73+
/// %ptrue_s, %tile, %tile_slice_idx
74+
/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
7375
/// }
7476
/// ```
7577
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
7678
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
7779

7880
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
7981
PatternRewriter &rewriter) const override {
82+
if (tileLoadOp.getMask())
83+
// TODO: add masked patterns.
84+
return rewriter.notifyMatchFailure(
85+
tileLoadOp, "op has mask, needs masked pattern(s)");
86+
8087
OpBuilder::InsertionGuard g(rewriter);
8188
auto loc = tileLoadOp.getLoc();
8289
auto tileType = tileLoadOp.getVectorType();
@@ -109,6 +116,12 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
109116

110117
rewriter.setInsertionPointToStart(forOp.getBody());
111118

119+
// Create an 'all true' predicate for the tile slice.
120+
auto predicateType =
121+
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
122+
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
123+
loc, DenseElementsAttr::get(predicateType, true));
124+
112125
// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
113126
// tile.
114127
SmallVector<Value> memrefIndices;
@@ -117,8 +130,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
117130
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
118131
numTileSlices, memrefIndices, loc, rewriter);
119132
rewriter.create<arm_sme::LoadTileSliceOp>(
120-
loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
121-
tileSliceIndex, tileLoadOp.getLayout());
133+
loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
134+
memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
122135

123136
rewriter.setInsertionPointAfter(forOp);
124137

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,7 @@ struct LoadTileSliceToArmSMELowering
179179
loc, rewriter.getI32Type(), tileSlice);
180180

181181
// Create all active predicate mask.
182-
auto one = rewriter.create<arith::ConstantOp>(
183-
loc, rewriter.getI1Type(),
184-
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
185-
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
186-
/*scalableDims=*/{true});
187-
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
182+
auto maskOp = loadTileSliceOp.getMask();
188183

189184
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
190185
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
@@ -195,48 +190,48 @@ struct LoadTileSliceToArmSMELowering
195190
default:
196191
llvm_unreachable("unexpected element type!");
197192
case 8:
198-
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
199-
loc, allActiveMask, ptr, tileI32, tileSliceI32);
193+
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
194+
tileI32, tileSliceI32);
200195
break;
201196
case 16:
202-
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
203-
loc, allActiveMask, ptr, tileI32, tileSliceI32);
197+
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
198+
tileI32, tileSliceI32);
204199
break;
205200
case 32:
206-
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
207-
loc, allActiveMask, ptr, tileI32, tileSliceI32);
201+
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
202+
tileI32, tileSliceI32);
208203
break;
209204
case 64:
210-
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
211-
loc, allActiveMask, ptr, tileI32, tileSliceI32);
205+
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
206+
tileI32, tileSliceI32);
212207
break;
213208
case 128:
214-
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
215-
loc, allActiveMask, ptr, tileI32, tileSliceI32);
209+
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
210+
tileI32, tileSliceI32);
216211
break;
217212
}
218213
} else {
219214
switch (tileElementWidth) {
220215
default:
221216
llvm_unreachable("unexpected element type!");
222217
case 8:
223-
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
218+
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
224219
tileI32, tileSliceI32);
225220
break;
226221
case 16:
227-
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
222+
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
228223
tileI32, tileSliceI32);
229224
break;
230225
case 32:
231-
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
226+
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
232227
tileI32, tileSliceI32);
233228
break;
234229
case 64:
235-
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
230+
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
236231
tileI32, tileSliceI32);
237232
break;
238233
case 128:
239-
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
234+
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
240235
tileI32, tileSliceI32);
241236
break;
242237
}

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
22

3+
//===----------------------------------------------------------------------===//
4+
// arm_sme.tile_load
5+
//===----------------------------------------------------------------------===//
6+
37
// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
48
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
59
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
@@ -10,8 +14,9 @@
1014
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
1115
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
1216
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
17+
// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
1318
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
14-
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
19+
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
1520
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
1621
%c0 = arith.constant 0 : index
1722
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -28,6 +33,10 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
2833
return
2934
}
3035

36+
//===----------------------------------------------------------------------===//
37+
// arm_sme.tile_store
38+
//===----------------------------------------------------------------------===//
39+
3140
// -----
3241

3342
// CHECK-LABEL: func.func @arm_sme_tile_store_hor(
@@ -57,6 +66,10 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
5766
return
5867
}
5968

69+
//===----------------------------------------------------------------------===//
70+
// vector.print
71+
//===----------------------------------------------------------------------===//
72+
6073
// -----
6174

6275
func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)

0 commit comments

Comments
 (0)