Skip to content

Commit e7b1522

Browse files
committed
address comments
1 parent 69e1891 commit e7b1522

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
392392
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
393393
AllTypesMatch<["tile", "result"]>,
394394
TypesMatchWith<
395-
"mask has i1 element type and same shape as result",
395+
"mask has i1 element type and is a slice of the result",
396396
"result", "mask",
397397
"VectorType("
398398
"VectorType::Builder("
@@ -434,7 +434,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
434434
```
435435
}];
436436
let arguments = (ins
437-
Arg<AnyMemRef, "the reference to load from">:$base, AnyVector:$mask,
437+
Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
438438
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
439439
ArmSME_TileSliceLayoutAttr:$layout
440440
);

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
8080
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
8181
PatternRewriter &rewriter) const override {
8282
if (tileLoadOp.getMask())
83+
// TODO: add masked patterns.
8384
return rewriter.notifyMatchFailure(
8485
tileLoadOp, "op has mask, needs masked pattern(s)");
8586

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
159159

160160
func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask : vector<[2]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
161161
%c0 = arith.constant 0 : index
162-
// expected-error@+1 {{op failed to verify that mask has i1 element type and same shape as result}}
162+
// expected-error@+1 {{op failed to verify that mask has i1 element type and is a slice of the result}}
163163
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[2]xi1>, vector<[16]x[16]xi8>
164164
return
165165
}

0 commit comments

Comments
 (0)