-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Tighten the semantics of vector.gather #135749
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Tighten the semantics of vector.gather #135749
Conversation
This patch restricts `vector.gather` to only accept tensors and memrefs as valid sources. Currently, the source is typed as `AnyShaped`, which also includes vectors—allowing the following (invalid) construct to pass verification: ```mlir %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` (Note: the source %base here is a vector, which is incorrect.) In contrast, `vector.scatter` currently only accepts memrefs, so some asymmetry remains between the two ops. This PR is a step toward aligning their semantics.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Andrzej Warzyński (banach-space) ChangesThis patch restricts %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> (Note: the source %base here is a vector, which is incorrect.) In contrast, Full diff: https://github.com/llvm/llvm-project/pull/135749.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..8ae5961af41bb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1971,7 +1971,7 @@ def Vector_GatherOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
- Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
+ Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e6f17ded4628b..45ec1846580f2 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
+// Whether a type is a TensorType or a MemRefType.
+def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
+
// Whether a type is an UnrankedMemRefType
def IsUnrankedMemRefTypePred
: CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, HasValueSemanticsPred,
"container with value semantics">;
+//===----------------------------------------------------------------------===//
// Vector types.
+//===----------------------------------------------------------------------===//
class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
//===----------------------------------------------------------------------===//
-// Memref type.
+// Memref types.
//===----------------------------------------------------------------------===//
// Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -878,6 +883,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
"getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
"nested tuple">;
+//===----------------------------------------------------------------------===//
+// Mixed types
+//===----------------------------------------------------------------------===//
+
+class TensorOrMemRef<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
+ "::mlir::ShapedType">;
+
//===----------------------------------------------------------------------===//
// Common type constraints
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..c6e780c6641fd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
// -----
+func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
// -----
+func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+ vector.scatter %base[%c0][%indices], %mask, %pass_thru
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+
func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
|
@llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesThis patch restricts %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> (Note: the source %base here is a vector, which is incorrect.) In contrast, Full diff: https://github.com/llvm/llvm-project/pull/135749.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..8ae5961af41bb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1971,7 +1971,7 @@ def Vector_GatherOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
- Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
+ Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e6f17ded4628b..45ec1846580f2 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
+// Whether a type is a TensorType or a MemRefType.
+def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
+
// Whether a type is an UnrankedMemRefType
def IsUnrankedMemRefTypePred
: CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, HasValueSemanticsPred,
"container with value semantics">;
+//===----------------------------------------------------------------------===//
// Vector types.
+//===----------------------------------------------------------------------===//
class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
//===----------------------------------------------------------------------===//
-// Memref type.
+// Memref types.
//===----------------------------------------------------------------------===//
// Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -878,6 +883,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
"getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
"nested tuple">;
+//===----------------------------------------------------------------------===//
+// Mixed types
+//===----------------------------------------------------------------------===//
+
+class TensorOrMemRef<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
+ "::mlir::ShapedType">;
+
//===----------------------------------------------------------------------===//
// Common type constraints
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..c6e780c6641fd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
// -----
+func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
// -----
+func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+ vector.scatter %base[%c0][%indices], %mask, %pass_thru
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+
func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
|
@llvm/pr-subscribers-mlir-ods Author: Andrzej Warzyński (banach-space) ChangesThis patch restricts %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> (Note: the source %base here is a vector, which is incorrect.) In contrast, Full diff: https://github.com/llvm/llvm-project/pull/135749.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..8ae5961af41bb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1971,7 +1971,7 @@ def Vector_GatherOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
- Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
+ Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e6f17ded4628b..45ec1846580f2 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
+// Whether a type is a TensorType or a MemRefType.
+def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
+
// Whether a type is an UnrankedMemRefType
def IsUnrankedMemRefTypePred
: CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, HasValueSemanticsPred,
"container with value semantics">;
+//===----------------------------------------------------------------------===//
// Vector types.
+//===----------------------------------------------------------------------===//
class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
//===----------------------------------------------------------------------===//
-// Memref type.
+// Memref types.
//===----------------------------------------------------------------------===//
// Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -878,6 +883,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
"getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
"nested tuple">;
+//===----------------------------------------------------------------------===//
+// Mixed types
+//===----------------------------------------------------------------------===//
+
+class TensorOrMemRef<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
+ "::mlir::ShapedType">;
+
//===----------------------------------------------------------------------===//
// Common type constraints
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..c6e780c6641fd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
// -----
+func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
// -----
+func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+ vector.scatter %base[%c0][%indices], %mask, %pass_thru
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+
func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
|
This patch restricts `vector.gather` to only accept tensors and memrefs as valid sources. Currently, the source is typed as `AnyShaped`, which also includes vectors—allowing the following (invalid) construct to pass verification: ```mlir %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` (Note: the source %base here is a vector, which is incorrect.) In contrast, `vector.scatter` currently only accepts memrefs, so some asymmetry remains between the two ops. This PR is a step toward aligning their semantics.
This patch restricts
vector.gather
to only accept tensors and memrefsas valid sources. Currently, the source is typed as
AnyShaped
, whichalso includes vectors—allowing the following (invalid) construct to pass
verification:
(Note: the source %base here is a vector, which is incorrect.)
In contrast,
vector.scatter
currently only accepts memrefs, so someasymmetry remains between the two ops. This PR is a step toward aligning
their semantics.