Skip to content

[mlir][ArmSME] Add optional mask operand to tile_store #70657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
"::llvm::cast<VectorType>($_self).getElementType())"
".getWidth())">;

class HasMatchingMaskTypeConstraint<string vector, string mask> :
OptionalTypesMatchWith<
mask # " has i1 element type and same shape as " # vector,
vector, mask,
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
Comment on lines +63 to +67
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MacDue I've based on the constraint you're adding in #69604 but parameterized it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to rebase and de-duplicate this now :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done :)


//===----------------------------------------------------------------------===//
// ArmSME attr definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -259,14 +265,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
"result", "padding",
"::llvm::cast<VectorType>($_self).getElementType()"
>,
OptionalTypesMatchWith<
"mask has i1 element type and same shape as result",
"result", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").setElementType(IntegerType::get($_self.getContext(), 1)))"
>,
HasMatchingMaskTypeConstraint<"result", "mask">,
PredOpTrait<
"both `padding` and `mask` should be provided or neither",
CPred<"bool(getPadding()) == bool(getMask())">
Expand Down Expand Up @@ -345,7 +344,10 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
"attr-dict `:` type($base) `,` type($result)";
}

def TileStoreOp : ArmSME_Op<"tile_store"> {
def TileStoreOp : ArmSME_Op<"tile_store", [
AttrSizedOperandSegments,
HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
]> {
let summary = "Tile store operation";
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
Expand All @@ -356,6 +358,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
rank 2 with dynamic dimensions, since the operation is scalable, and the
element type must be a scalar that matches the element type of the result.

An optional `mask` may be provided, the shape of which corresponds to the
`tile`, and selects which elements of the tile will be stored.

Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
Expand All @@ -370,10 +375,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
```

Example 4: Masked store a int 32-bit element ZA tile with vertical layout to memory.
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```
}];
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
Variadic<Index>:$indices, Optional<AnyVector>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
Expand All @@ -384,9 +395,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];

let builders = [
OpBuilder<(ins "Value":$valueToStore, "Value":$base,
"ValueRange":$indices), [{
build($_builder, $_state, valueToStore, base, indices, {});
}]>,
];

let assemblyFormat =
"$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($valueToStore)";
"$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?"
"attr-dict `:` type($base) `,` type($valueToStore)";
}

def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
Expand Down Expand Up @@ -595,12 +613,6 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
}];
}

class HasMatchingMaskTypeConstraint<string operand> :
OptionalTypesMatchWith<
"shape of `" # operand # "Mask` matches `" # operand # "`",
operand, operand # "Mask",
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;

class OuterProductResultTileTypeConstraint<string operand> :
OptionalTypesMatchWith<operand # "type is derived from `lhs` and `rhs`",
"lhs", operand,
Expand All @@ -615,8 +627,8 @@ def OuterProductOp :
ArmSME_Op<"outerproduct", [Pure,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
HasMatchingMaskTypeConstraint<"lhs">,
HasMatchingMaskTypeConstraint<"rhs">,
HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
PredOpTrait<
"both `lhsMask` and `rhsMask` should be provided or neither",
CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ struct TransferWriteToArmSMELowering
return failure();

rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
writeOp, writeOp.getVector(), writeOp.getSource(),
writeOp.getIndices());
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
writeOp.getMask());
return success();
}
};
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/ArmSME/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,20 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
return
}

//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask : vector<[1]x[1]xi1>, %dest : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
// expected-note@-2 {{prior use here}}
// expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[16]x[16]xi1>' vs 'vector<[1]x[1]xi1>}}
arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}

//===----------------------------------------------------------------------===//
// arm_sme.outerproduct
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/ArmSME/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,15 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre

// -----

func.func @arm_sme_tile_store_with_mask_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
// CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}

// -----

/// Layout is optional and horizontal is the default, verify it's still parsed.
func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
// CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,20 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?

// -----

// CHECK-LABEL: func.func @transfer_write_2d_with_mask_f64(
// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf64>,
// CHECK-SAME: %[[MASK:.*]]: vector<[2]x[2]xi1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] : memref<?x?xf64>, vector<[2]x[2]xf64>
func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
return
}

// -----

// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
// lowering only occurs for vector types of correct rank, shape, element size
// and number of scalable dims.
Expand Down