Skip to content

[mlir][vector] Tighten the semantics of vector.gather #135749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

This patch restricts vector.gather to only accept tensors and memrefs
as valid sources. Currently, the source is typed as AnyShaped, which
also includes vectors—allowing the following (invalid) construct to pass
verification:

  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
       : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

(Note: the source %base here is a vector, which is incorrect.)

In contrast, vector.scatter currently only accepts memrefs, so some
asymmetry remains between the two ops. This PR is a step toward aligning
their semantics.

This patch restricts `vector.gather` to only accept tensors and memrefs
as valid sources. Currently, the source is typed as `AnyShaped`, which
also includes vectors—allowing the following (invalid) construct to pass
verification:

```mlir
  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
       : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
(Note: the source %base here is a vector, which is incorrect.)

In contrast, `vector.scatter` currently only accepts memrefs, so some
asymmetry remains between the two ops. This PR is a step toward aligning
their semantics.
@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Andrzej Warzyński (banach-space)

Changes

This patch restricts vector.gather to only accept tensors and memrefs
as valid sources. Currently, the source is typed as AnyShaped, which
also includes vectors—allowing the following (invalid) construct to pass
verification:

  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
       : vector&lt;16xf32&gt;, vector&lt;16xi32&gt;, vector&lt;16xi1&gt;, vector&lt;16xf32&gt; into vector&lt;16xf32&gt;

(Note: the source %base here is a vector, which is incorrect.)

In contrast, vector.scatter currently only accepts memrefs, so some
asymmetry remains between the two ops. This PR is a step toward aligning
their semantics.


Full diff: https://github.com/llvm/llvm-project/pull/135749.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+14-1)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..8ae5961af41bb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1971,7 +1971,7 @@ def Vector_GatherOp :
     DeclareOpInterfaceMethods<MaskableOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
-    Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
+    Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                VectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e6f17ded4628b..45ec1846580f2 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
 // Whether a type is a MemRefType.
 def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
 
+// Whether a type is a TensorType or a MemRefType.
+def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
+
 // Whether a type is an UnrankedMemRefType
 def IsUnrankedMemRefTypePred
         : CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, HasValueSemanticsPred,
   "container with value semantics">;
 
+//===----------------------------------------------------------------------===//
 // Vector types.
+//===----------------------------------------------------------------------===//
 
 class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
 def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
 
 //===----------------------------------------------------------------------===//
-// Memref type.
+// Memref types.
 //===----------------------------------------------------------------------===//
 
 // Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -878,6 +883,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
                        "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
                        "nested tuple">;
 
+//===----------------------------------------------------------------------===//
+// Mixed types
+//===----------------------------------------------------------------------===//
+
+class TensorOrMemRef<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
+                      "::mlir::ShapedType">;
+
 //===----------------------------------------------------------------------===//
 // Common type constraints
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..c6e780c6641fd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
 
 // -----
 
+func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
 
 // -----
 
+func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+                             %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+  vector.scatter %base[%c0][%indices], %mask, %pass_thru
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+
 func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index

@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2025

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

This patch restricts vector.gather to only accept tensors and memrefs
as valid sources. Currently, the source is typed as AnyShaped, which
also includes vectors—allowing the following (invalid) construct to pass
verification:

  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
       : vector&lt;16xf32&gt;, vector&lt;16xi32&gt;, vector&lt;16xi1&gt;, vector&lt;16xf32&gt; into vector&lt;16xf32&gt;

(Note: the source %base here is a vector, which is incorrect.)

In contrast, vector.scatter currently only accepts memrefs, so some
asymmetry remains between the two ops. This PR is a step toward aligning
their semantics.


Full diff: https://github.com/llvm/llvm-project/pull/135749.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+14-1)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..8ae5961af41bb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1971,7 +1971,7 @@ def Vector_GatherOp :
     DeclareOpInterfaceMethods<MaskableOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
-    Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
+    Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                VectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e6f17ded4628b..45ec1846580f2 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
 // Whether a type is a MemRefType.
 def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
 
+// Whether a type is a TensorType or a MemRefType.
+def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
+
 // Whether a type is an UnrankedMemRefType
 def IsUnrankedMemRefTypePred
         : CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, HasValueSemanticsPred,
   "container with value semantics">;
 
+//===----------------------------------------------------------------------===//
 // Vector types.
+//===----------------------------------------------------------------------===//
 
 class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
 def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
 
 //===----------------------------------------------------------------------===//
-// Memref type.
+// Memref types.
 //===----------------------------------------------------------------------===//
 
 // Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -878,6 +883,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
                        "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
                        "nested tuple">;
 
+//===----------------------------------------------------------------------===//
+// Mixed types
+//===----------------------------------------------------------------------===//
+
+class TensorOrMemRef<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
+                      "::mlir::ShapedType">;
+
 //===----------------------------------------------------------------------===//
 // Common type constraints
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..c6e780c6641fd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
 
 // -----
 
+func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
 
 // -----
 
+func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+                             %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+  vector.scatter %base[%c0][%indices], %mask, %pass_thru
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+
 func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index

@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2025

@llvm/pr-subscribers-mlir-ods

Author: Andrzej Warzyński (banach-space)

Changes

This patch restricts vector.gather to only accept tensors and memrefs
as valid sources. Currently, the source is typed as AnyShaped, which
also includes vectors—allowing the following (invalid) construct to pass
verification:

  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
       : vector&lt;16xf32&gt;, vector&lt;16xi32&gt;, vector&lt;16xi1&gt;, vector&lt;16xf32&gt; into vector&lt;16xf32&gt;

(Note: the source %base here is a vector, which is incorrect.)

In contrast, vector.scatter currently only accepts memrefs, so some
asymmetry remains between the two ops. This PR is a step toward aligning
their semantics.


Full diff: https://github.com/llvm/llvm-project/pull/135749.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+14-1)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..8ae5961af41bb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1971,7 +1971,7 @@ def Vector_GatherOp :
     DeclareOpInterfaceMethods<MaskableOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
-    Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
+    Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                VectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e6f17ded4628b..45ec1846580f2 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
 // Whether a type is a MemRefType.
 def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
 
+// Whether a type is a TensorType or a MemRefType.
+def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
+
 // Whether a type is an UnrankedMemRefType
 def IsUnrankedMemRefTypePred
         : CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, HasValueSemanticsPred,
   "container with value semantics">;
 
+//===----------------------------------------------------------------------===//
 // Vector types.
+//===----------------------------------------------------------------------===//
 
 class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
 def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
 
 //===----------------------------------------------------------------------===//
-// Memref type.
+// Memref types.
 //===----------------------------------------------------------------------===//
 
 // Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -878,6 +883,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
                        "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
                        "nested tuple">;
 
+//===----------------------------------------------------------------------===//
+// Mixed types
+//===----------------------------------------------------------------------===//
+
+class TensorOrMemRef<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
+                      "::mlir::ShapedType">;
+
 //===----------------------------------------------------------------------===//
 // Common type constraints
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..c6e780c6641fd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
 
 // -----
 
+func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
 
 // -----
 
+func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+                             %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+  vector.scatter %base[%c0][%indices], %mask, %pass_thru
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+
 func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index

@banach-space banach-space merged commit 1e61b37 into llvm:main Apr 16, 2025
16 of 17 checks passed
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
This patch restricts `vector.gather` to only accept tensors and memrefs
as valid sources. Currently, the source is typed as `AnyShaped`, which
also includes vectors—allowing the following (invalid) construct to pass
verification:

```mlir
  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
       : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
(Note: the source %base here is a vector, which is incorrect.)

In contrast, `vector.scatter` currently only accepts memrefs, so some
asymmetry remains between the two ops. This PR is a step toward aligning
their semantics.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants