[mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors #117538

Merged

merged 2 commits into llvm:main from andrzej/disable_compress_expand on Nov 29, 2024

Conversation

banach-space
Contributor

@banach-space banach-space commented Nov 25, 2024

These operations were introduced as counterparts to the following LLVM
intrinsics:

  • @llvm.masked.expandload.*,
  • @llvm.masked.compressstore.*.

Currently, there is minimal test coverage for scalable vector use cases
involving these Ops (both LLVM and MLIR). Additionally, the verifier is
flawed - it incorrectly allows mixing fixed-width and scalable vectors.

To address these issues, scalable vector support for these Ops is being
disabled for now. This decision can be revisited if a clear need arises
for their use with scalable vectors in the future.
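
As an illustration of the verifier gap (hypothetical IR, operand names assumed): with the current `VectorOf<[I1]>`/`AnyVector` constraints, an Op mixing a fixed-width mask with scalable data operands is not rejected, e.g.:

```mlir
// Hypothetical sketch - a fixed-width mask combined with scalable data operands.
// With this patch, vector.expandload requires all operands to be fixed-width,
// so IR like this is rejected by the ODS constraints.
%0 = vector.expandload %base[%c0], %mask, %pass_thru
    : memref<?xf32>, vector<16xi1>, vector<[16]xf32> into vector<[16]xf32>
```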

@banach-space banach-space changed the title from "andrzej/disable compress expand" to "[mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors" on Nov 25, 2024
@llvmbot
Member

llvmbot commented Nov 25, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector] Rename vector type TD definitions (nfc)
  • [mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors

Full diff: https://github.com/llvm/llvm-project/pull/117538.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+27-27)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+21-15)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+27-17)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 9cc792093bf836..475b11f12c5f01 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect {
 //===----------------------------------------------------------------------===//
 
 class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
-  [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
+  [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
   "a vector with length " # length,
   "::mlir::VectorType">;
 
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d7e8b22fbd2d35..cdcf4d8752e874 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -100,11 +100,11 @@ class ScalableMaskedFOp<string mnemonic, string op_description,
     op_description # [{ on active lanes. Inactive lanes will keep the value of
     the first operand.}];
   let arguments = (ins
-          ScalableVectorOf<[I1]>:$mask,
-          ScalableVectorOf<[AnyFloat]>:$src1,
-          ScalableVectorOf<[AnyFloat]>:$src2
+          ScalableVectorOfAnyRank<[I1]>:$mask,
+          ScalableVectorOfAnyRank<[AnyFloat]>:$src1,
+          ScalableVectorOfAnyRank<[AnyFloat]>:$src2
   );
-  let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
+  let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res);
   let assemblyFormat =
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
@@ -123,11 +123,11 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
     op_description # [{ on active lanes. Inactive lanes will keep the value of
     the first operand.}];
   let arguments = (ins
-          ScalableVectorOf<[I1]>:$mask,
-          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
-          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+          ScalableVectorOfAnyRank<[I1]>:$mask,
+          ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2
   );
-  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
+  let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res);
   let assemblyFormat =
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
@@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def UdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udot">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedAddIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"add">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedAddFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fadd">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedMulIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"mul">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedMulFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fmul">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSubIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sub">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSubFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fsub">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSDivIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedUDivIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedDivFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ConvertFromSvboolIntrOp :
   ArmSVE_IntrOp<"convert.from.svbool",
@@ -581,8 +581,8 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
     /*overloadedOperands=*/[0],
     /*overloadedResults=*/[],
     /*numResults=*/2>,
-    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
-                   Arg<AnyScalableVector, "v2">:$v2)>;
+    Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+                   Arg<AnyScalableVectorOfAnyRank, "v2">:$v2)>;
 
 // Note: This multi-vector intrinsic requires SME2.
 def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
@@ -590,10 +590,10 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
     /*overloadedOperands=*/[0],
     /*overloadedResults=*/[],
     /*numResults=*/4>,
-    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
-                   Arg<AnyScalableVector, "v2">:$v2,
-                   Arg<AnyScalableVector, "v3">:$v3,
-                   Arg<AnyScalableVector, "v3">:$v4)>;
+    Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+                   Arg<AnyScalableVectorOfAnyRank, "v2">:$v2,
+                   Arg<AnyScalableVectorOfAnyRank, "v3">:$v3,
+                   Arg<AnyScalableVectorOfAnyRank, "v3">:$v4)>;
 
 // Note: This intrinsic requires SME or SVE2.1.
 def PselIntrOp : ArmSVE_IntrOp<"psel",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cc4cafa869e63a..5911355abd5146 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -417,16 +417,18 @@ def Vector_BroadcastOp :
   let hasVerifier = 1;
 }
 
-def Vector_ShuffleOp :
-  Vector_Op<"shuffle", [Pure,
-     PredOpTrait<"first operand v1 and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>,
-     PredOpTrait<"second operand v2 and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 1>>,
-     InferTypeOpAdaptor]>,
-     Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
-                    DenseI64ArrayAttr:$mask)>,
-     Results<(outs AnyVector:$vector)> {
+def Vector_ShuffleOp
+    : Vector_Op<
+          "shuffle",
+          [Pure,
+           PredOpTrait<"first operand v1 and result have same element type",
+                       TCresVTEtIsSameAsOpBase<0, 0>>,
+           PredOpTrait<"second operand v2 and result have same element type",
+                       TCresVTEtIsSameAsOpBase<0, 1>>,
+           InferTypeOpAdaptor]>,
+      Arguments<(ins AnyFixedVectorOfAnyRank:$v1, AnyFixedVectorOfAnyRank:$v2,
+          DenseI64ArrayAttr:$mask)>,
+      Results<(outs AnyVector:$vector)> {
   let summary = "shuffle operation";
   let description = [{
     The shuffle operation constructs a permutation (or duplication) of elements
@@ -2082,9 +2084,9 @@ def Vector_ExpandLoadOp :
   Vector_Op<"expandload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$pass_thru)>,
-    Results<(outs AnyVector:$result)> {
+               FixedVectorOf<[I1]>:$mask,
+               AnyFixedVector:$pass_thru)>,
+    Results<(outs AnyFixedVector:$result)> {
 
   let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
 
@@ -2115,6 +2117,8 @@ def Vector_ExpandLoadOp :
     correspond to those of the `llvm.masked.expandload`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
 
+    Note, at the moment this Op is only available for fixed-width vectors.
+
     Examples:
 
     ```mlir
@@ -2149,8 +2153,8 @@ def Vector_CompressStoreOp :
   Vector_Op<"compressstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$valueToStore)> {
+               FixedVectorOf<[I1]>:$mask,
+               AnyFixedVector:$valueToStore)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
 
@@ -2181,6 +2185,8 @@ def Vector_CompressStoreOp :
     correspond to those of the `llvm.masked.compressstore`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
 
+    Note that the index increment is done conditionally.
+
     Examples:
 
     ```mlir
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 48e4c24f838652..ef1645117a7280 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,13 +24,16 @@ include "mlir/IR/DialectBase.td"
 // Explicitly disallow 0-D vectors for now until we have good enough coverage.
 def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                             CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
+def IsFixedVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+                                 CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+                                 CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
 
 // Whether a type is a fixed-length VectorType.
-def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                   !::llvm::cast<VectorType>($_self).isScalable()}]>;
 
 // Whether a type is a scalable VectorType.
@@ -432,17 +435,21 @@ class VectorOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
                       "::mlir::VectorType">;
 
+class FixedVectorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
+          "fixed-length vector", "::mlir::VectorType">;
+
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 class VectorOfAnyRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
                       "::mlir::VectorType">;
 
-class FixedVectorOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
+class FixedVectorOfAnyRank<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
           "fixed-length vector", "::mlir::VectorType">;
 
-class ScalableVectorOf<list<Type> allowedTypes> :
+class ScalableVectorOfAnyRank<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
           "scalable vector", "::mlir::VectorType">;
 
@@ -467,7 +474,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
 // Whether the number of elements of a fixed-length vector is from the given
 // `allowedRanks` list
 class IsFixedVectorOfRankPred<list<int> allowedRanks> :
-  And<[IsFixedVectorTypePred,
+  And<[IsFixedVectorOfAnyRankTypePred,
        Or<!foreach(allowedlength, allowedRanks,
                    CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                            == }]
@@ -509,8 +516,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
 // the type is from the given `allowedTypes` list
 class FixedVectorOfRankAndType<list<int> allowedRanks,
                           list<Type> allowedTypes> : AllOfType<
-  [FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
-  FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
+  [FixedVectorOfAnyRank<allowedTypes>, VectorOfRank<allowedRanks>],
+  FixedVectorOfAnyRank<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
   "::mlir::VectorType">;
 
 // Whether the number of elements of a vector is from the given
@@ -525,7 +532,7 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
 // Whether the number of elements of a fixed-length vector is from the given
 // `allowedLengths` list
 class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
-  And<[IsFixedVectorTypePred,
+  And<[IsFixedVectorOfAnyRankTypePred,
        Or<!foreach(allowedlength, allowedLengths,
                    CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                            == }]
@@ -612,8 +619,8 @@ class VectorOfLengthAndType<list<int> allowedLengths,
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class FixedVectorOfLengthAndType<list<int> allowedLengths,
                                  list<Type> allowedTypes> : AllOfType<
-  [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
-  FixedVectorOf<allowedTypes>.summary #
+  [FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
+  FixedVectorOfAnyRank<allowedTypes>.summary #
   FixedVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -621,8 +628,8 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class ScalableVectorOfLengthAndType<list<int> allowedLengths,
                                     list<Type> allowedTypes> : AllOfType<
-  [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
-  ScalableVectorOf<allowedTypes>.summary #
+  [ScalableVectorOfAnyRank<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
+  ScalableVectorOfAnyRank<allowedTypes>.summary #
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -632,10 +639,10 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
 class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
                                            list<int> allowedLengths,
                                            list<Type> allowedTypes> : AllOfType<
-  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
+  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOfAnyRank<allowedTypes>,
    ScalableVectorOfLength<allowedLengths>],
   ScalableVectorOfRank<allowedRanks>.summary #
-  ScalableVectorOf<allowedTypes>.summary #
+  ScalableVectorOfAnyRank<allowedTypes>.summary #
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -657,13 +664,16 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
    ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
   "::mlir::VectorType">;
 
+// Unlike the following definitions, this one excludes 0-D vectors
 def AnyVector : VectorOf<[AnyType]>;
-// Temporary vector type clone that allows gradual transition to 0-D vectors.
-def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
 
 def AnyFixedVector : FixedVectorOf<[AnyType]>;
 
-def AnyScalableVector : ScalableVectorOf<[AnyType]>;
+def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
+
+def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
+
+def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
 
 // Shaped types.
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 0c093b0ccff141..ae336d4f5ddb8d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1519,6 +1519,14 @@ func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
 
 // -----
 
+func.func @expand_base_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %pass_thru: vector<[16]xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{'vector.expandload' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
+}
+
+// -----
+
 func.func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
@@ -1551,6 +1559,14 @@ func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1
 
 // -----
 
+func.func @compress_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %value: vector<[16]xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{'vector.compressstore' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
+  vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
+}
+
+// -----
+
 func.func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}

@llvmbot
Member

llvmbot commented Nov 25, 2024

@llvm/pr-subscribers-mlir-ods

@llvmbot
Member

llvmbot commented Nov 25, 2024

@llvm/pr-subscribers-mlir-neon

@banach-space banach-space requested a review from nujaa November 25, 2024 10:53
These operations were introduced as counterparts to the following LLVM
intrinsics:

  * `@llvm.masked.expandload.*`,
  * `@llvm.masked.compressstore.*`.

Currently, there is minimal test coverage for scalable vector use cases
involving these Ops (both LLVM and MLIR). Additionally, the verifier is
flawed - it incorrectly allows mixing fixed-width and scalable vectors.

To address these issues, scalable vector support for these Ops is being
disabled for now. This decision can be revisited if a clear need arises
for their use with scalable vectors in the future.

**NOTE:** Depends on llvm#117150 - please only review the top commit.
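
For reference, the fixed-width forms of both Ops remain valid; a minimal sketch (memref/operand names assumed, following the syntax used in the new tests):

```mlir
// Fixed-width operands only - still accepted after this change (sketch).
%0 = vector.expandload %base[%c0], %mask, %pass_thru
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
vector.compressstore %base[%c0], %mask, %value
    : memref<?xf32>, vector<16xi1>, vector<16xf32>
```
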
@banach-space banach-space force-pushed the andrzej/disable_compress_expand branch from fcc429a to 95ae663 on November 27, 2024 09:37
Contributor

@nujaa nujaa left a comment

Hi, I think there has been a copy-paste issue :)

@@ -2183,6 +2185,8 @@ def Vector_CompressStoreOp :
correspond to those of the `llvm.masked.compressstore`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).

Note that the index increment is done conditionally.
Contributor

Why copy this from l. 2175?

Contributor

Did you mean
Note, at the moment this Op is only available for fixed-width vectors.

Contributor Author

Rebase gone wrong 🤦🏻 Thanks!

Contributor

@dcaballe dcaballe left a comment

Thanks!

Member

@kuhar kuhar left a comment

LGTM

Comment on lines +27 to +29
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
                                              CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
Member

Digression: It'd be nice to have isa<FixedVectorType>. I remember we talked about it a few months ago -- do you know what the outcome was?

Contributor Author

I've implemented it and you've approved it:

:) I've not landed it yet - we've been discussing how "scalability" is modelled and I wanted to avoid merging it prematurely. And then life happened 🤷🏻‍♂️

I will land it this week, unless there are new comments.

@banach-space
Contributor Author

@nujaa I will assume that you are OK with this change and land it. But do let me know if you have further comments and I can address them post-commit. Or even revert if need be.

Thanks!

@banach-space banach-space merged commit 38098b4 into llvm:main Nov 29, 2024
8 checks passed