-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][vector] Relax operand type restrictions for vector.splat
#145517
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Relax operand type restrictions for vector.splat
#145517
Conversation
@llvm/pr-subscribers-mlir-vector Author: Matthias Springer (matthias-springer) ChangesThe vector type allows element types that implement the Full diff: https://github.com/llvm/llvm-project/pull/145517.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e62930a742d..d58ee84bee63d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2920,8 +2920,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
]> {
let summary = "vector splat or broadcast operation";
let description = [{
- Broadcast the operand to all elements of the result vector. The operand is
- required to be of integer/index/float type.
+ Broadcast the operand to all elements of the result vector. The type of the
+ operand must match the element type of the vector type.
Example:
@@ -2931,8 +2931,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
```
}];
- let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
- "integer/index/float type">:$input);
+ let arguments = (ins AnyType:$input);
let results = (outs AnyVectorOfAnyRank:$aggregate);
let builders = [
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c59f7bd001905..c65c89c6e09d1 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -959,13 +959,16 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
}
// CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32, [[S2:%arg[0-9]+]]: !llvm.ptr<1>
+func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
// CHECK: vector.splat [[S]] : vector<8xf32>
%v = vector.splat %s : vector<8xf32>
// CHECK: vector.splat [[S]] : vector<4xf32>
%u = "vector.splat"(%s) : (f32) -> vector<4xf32>
+
+ // CHECK: vector.splat [[S2]] : vector<16x!llvm.ptr<1>>
+ %w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
return
}
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe vector type allows element types that implement the Full diff: https://github.com/llvm/llvm-project/pull/145517.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e62930a742d..d58ee84bee63d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2920,8 +2920,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
]> {
let summary = "vector splat or broadcast operation";
let description = [{
- Broadcast the operand to all elements of the result vector. The operand is
- required to be of integer/index/float type.
+ Broadcast the operand to all elements of the result vector. The type of the
+ operand must match the element type of the vector type.
Example:
@@ -2931,8 +2931,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
```
}];
- let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
- "integer/index/float type">:$input);
+ let arguments = (ins AnyType:$input);
let results = (outs AnyVectorOfAnyRank:$aggregate);
let builders = [
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c59f7bd001905..c65c89c6e09d1 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -959,13 +959,16 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
}
// CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32, [[S2:%arg[0-9]+]]: !llvm.ptr<1>
+func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
// CHECK: vector.splat [[S]] : vector<8xf32>
%v = vector.splat %s : vector<8xf32>
// CHECK: vector.splat [[S]] : vector<4xf32>
%u = "vector.splat"(%s) : (f32) -> vector<4xf32>
+
+ // CHECK: vector.splat [[S2]] : vector<16x!llvm.ptr<1>>
+ %w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
return
}
|
vector.from_elements
vector.from_elements
vector.from_elements
vector.from_elements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
It would be good to add some negative tests and you also may want to update the title ;)
LGTM. Could we also check that |
9e44fa3
to
6f2a062
Compare
The title of the PR is accurate, isn't it? |
The PR description and commit title refer to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with the typos adjusted
vector.from_elements
vector.splat
// expected-note @+1 {{prior use here}} | ||
func.func @vector_splat_type_mismatch(%a: f32) { | ||
// expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}} | ||
%0 = vector.splat %a : vector<1xi32> | ||
return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[ultra-nit] #145699 :)
…vm#145517) The vector type allows element types that implement the `VectorElementTypeInterface`. `vector.splat` should allow any element type that is supported by the vector type.
…vm#145517) The vector type allows element types that implement the `VectorElementTypeInterface`. `vector.splat` should allow any element type that is supported by the vector type.
The vector type allows element types that implement the
VectorElementTypeInterface
.vector.splat
should allow any element type that is supported by the vector type.