Skip to content

Commit 38098b4

Browse files
authored
[mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors (#117538)
These operations were introduced as counterparts to the following LLVM intrinsics: * `@llvm.masked.expandload.*`, * `@llvm.masked.compressstore.*`. Currently, there is minimal test coverage for scalable vector use cases involving these Ops (both LLVM and MLIR). Additionally, the verifier is flawed - it incorrectly allows mixing fixed-width and scalable vectors. To address these issues, scalable vector support for these Ops is being disabled for now. This decision can be revisited if a clear need arises for their use with scalable vectors in the future.
1 parent 5c181a9 commit 38098b4

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2086,7 +2086,7 @@ def Vector_ExpandLoadOp :
20862086
Vector_Op<"expandload">,
20872087
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
20882088
Variadic<Index>:$indices,
2089-
VectorOfNonZeroRankOf<[I1]>:$mask,
2089+
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
20902090
AnyVectorOfNonZeroRank:$pass_thru)>,
20912091
Results<(outs AnyVectorOfNonZeroRank:$result)> {
20922092

@@ -2119,6 +2119,8 @@ def Vector_ExpandLoadOp :
21192119
correspond to those of the `llvm.masked.expandload`
21202120
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
21212121

2122+
Note, at the moment this Op is only available for fixed-width vectors.
2123+
21222124
Examples:
21232125

21242126
```mlir
@@ -2153,7 +2155,7 @@ def Vector_CompressStoreOp :
21532155
Vector_Op<"compressstore">,
21542156
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
21552157
Variadic<Index>:$indices,
2156-
VectorOfNonZeroRankOf<[I1]>:$mask,
2158+
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
21572159
AnyVectorOfNonZeroRank:$valueToStore)> {
21582160

21592161
let summary = "writes elements selectively from a vector as defined by a mask";
@@ -2185,6 +2187,8 @@ def Vector_CompressStoreOp :
21852187
correspond to those of the `llvm.masked.compressstore`
21862188
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
21872189

2190+
Note, at the moment this Op is only available for fixed-width vectors.
2191+
21882192
Examples:
21892193

21902194
```mlir

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ include "mlir/IR/DialectBase.td"
2424
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
2525
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
2626
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
27+
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
28+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
29+
CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
2730

2831
// Temporary vector type clone that allows gradual transition to 0-D vectors.
2932
// TODO: Remove this when all ops support 0-D vectors.
@@ -432,6 +435,10 @@ class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
432435
ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
433436
"::mlir::VectorType">;
434437

438+
class FixedVectorOfNonZeroRankOf<list<Type> allowedTypes> :
439+
ShapedContainerType<allowedTypes, IsFixedVectorOfNonZeroRankTypePred,
440+
"fixed-length vector", "::mlir::VectorType">;
441+
435442
// Temporary vector type clone that allows gradual transition to 0-D vectors.
436443
// TODO: Remove this when all ops support 0-D vectors.
437444
class VectorOfAnyRankOf<list<Type> allowedTypes> :
@@ -660,6 +667,8 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
660667
// Unlike the following definitions, this one excludes 0-D vectors
661668
def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;
662669

670+
def AnyFixedVectorOfNonZeroRank : FixedVectorOfNonZeroRankOf<[AnyType]>;
671+
663672
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
664673

665674
def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,14 @@ func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
15191519

15201520
// -----
15211521

1522+
func.func @expand_base_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %pass_thru: vector<[16]xf32>) {
1523+
%c0 = arith.constant 0 : index
1524+
// expected-error@+1 {{'vector.expandload' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
1525+
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
1526+
}
1527+
1528+
// -----
1529+
15221530
func.func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
15231531
%c0 = arith.constant 0 : index
15241532
// expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
@@ -1551,6 +1559,14 @@ func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1
15511559

15521560
// -----
15531561

1562+
func.func @compress_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %value: vector<[16]xf32>) {
1563+
%c0 = arith.constant 0 : index
1564+
// expected-error@+1 {{'vector.compressstore' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
1565+
vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
1566+
}
1567+
1568+
// -----
1569+
15541570
func.func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
15551571
%c0 = arith.constant 0 : index
15561572
// expected-error@+1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}

0 commit comments

Comments
 (0)