Skip to content

[mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors #117538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,7 @@ def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$pass_thru)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {

Expand Down Expand Up @@ -2117,6 +2117,8 @@ def Vector_ExpandLoadOp :
correspond to those of the `llvm.masked.expandload`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).

Note, at the moment this Op is only available for fixed-width vectors.

Examples:

```mlir
Expand Down Expand Up @@ -2151,7 +2153,7 @@ def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore)> {

let summary = "writes elements selectively from a vector as defined by a mask";
Expand Down Expand Up @@ -2183,6 +2185,8 @@ def Vector_CompressStoreOp :
correspond to those of the `llvm.masked.compressstore`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).

Note, at the moment this Op is only available for fixed-width vectors.

Examples:

```mlir
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ include "mlir/IR/DialectBase.td"
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
Comment on lines +27 to +29
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Digression: I'd be nice to have isa<FixedVectorType>. I remember we talked about it a few months ago -- do you know what the outcome was?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've implemented it and you've approved it:

:) I've not landed it yet - we've been discussing how "scalability" is modelled and I wanted to avoid merging it pre-maturely. And the life happened 🤷🏻‍♂️

I will land it this week, unless there's some new comments.


// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
Expand Down Expand Up @@ -432,6 +435,10 @@ class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
"::mlir::VectorType">;

class FixedVectorOfNonZeroRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsFixedVectorOfNonZeroRankTypePred,
"fixed-length vector", "::mlir::VectorType">;

// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
class VectorOfAnyRankOf<list<Type> allowedTypes> :
Expand Down Expand Up @@ -660,6 +667,8 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
// Unlike the following definitions, this one excludes 0-D vectors
def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;

def AnyFixedVectorOfNonZeroRank : FixedVectorOfNonZeroRankOf<[AnyType]>;

def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,14 @@ func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,

// -----

func.func @expand_base_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %pass_thru: vector<[16]xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.expandload' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
}

// -----

func.func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
Expand Down Expand Up @@ -1551,6 +1559,14 @@ func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1

// -----

func.func @compress_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %value: vector<[16]xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.compressstore' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
}

// -----

func.func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}
Expand Down