Skip to content

[mlir][vector] Relax operand type restrictions for vector.splat #145517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

The vector type allows element types that implement the VectorElementTypeInterface. vector.splat should allow any element type that is supported by the vector type.

@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-mlir-vector

Author: Matthias Springer (matthias-springer)

Changes

The vector type allows element types that implement the VectorElementTypeInterface. vector.splat should allow any element type that is supported by the vector type.


Full diff: https://github.com/llvm/llvm-project/pull/145517.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+3-4)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+5-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e62930a742d..d58ee84bee63d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2920,8 +2920,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
   ]> {
   let summary = "vector splat or broadcast operation";
   let description = [{
-    Broadcast the operand to all elements of the result vector. The operand is
-    required to be of integer/index/float type.
+    Broadcast the operand to all elements of the result vector. The type of the
+    operand must match the element type of the vector type.
 
     Example:
 
@@ -2931,8 +2931,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input);
+  let arguments = (ins AnyType:$input);
   let results = (outs AnyVectorOfAnyRank:$aggregate);
 
   let builders = [
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c59f7bd001905..c65c89c6e09d1 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -959,13 +959,16 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
 }
 
 // CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32, [[S2:%arg[0-9]+]]: !llvm.ptr<1>
+func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
   // CHECK: vector.splat [[S]] : vector<8xf32>
   %v = vector.splat %s : vector<8xf32>
 
   // CHECK: vector.splat [[S]] : vector<4xf32>
   %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
+
+  // CHECK: vector.splat [[S2]] : vector<16x!llvm.ptr<1>>
+  %w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
   return
 }
 

@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The vector type allows element types that implement the VectorElementTypeInterface. vector.splat should allow any element type that is supported by the vector type.


Full diff: https://github.com/llvm/llvm-project/pull/145517.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+3-4)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+5-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e62930a742d..d58ee84bee63d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2920,8 +2920,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
   ]> {
   let summary = "vector splat or broadcast operation";
   let description = [{
-    Broadcast the operand to all elements of the result vector. The operand is
-    required to be of integer/index/float type.
+    Broadcast the operand to all elements of the result vector. The type of the
+    operand must match the element type of the vector type.
 
     Example:
 
@@ -2931,8 +2931,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input);
+  let arguments = (ins AnyType:$input);
   let results = (outs AnyVectorOfAnyRank:$aggregate);
 
   let builders = [
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c59f7bd001905..c65c89c6e09d1 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -959,13 +959,16 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
 }
 
 // CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32, [[S2:%arg[0-9]+]]: !llvm.ptr<1>
+func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
   // CHECK: vector.splat [[S]] : vector<8xf32>
   %v = vector.splat %s : vector<8xf32>
 
   // CHECK: vector.splat [[S]] : vector<4xf32>
   %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
+
+  // CHECK: vector.splat [[S2]] : vector<16x!llvm.ptr<1>>
+  %w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
   return
 }
 

@matthias-springer matthias-springer changed the title [mlir][vector] Allow pointer types for vector.from_elements [mlir][vector] Remove type check from vector.from_elements Jun 24, 2025
@matthias-springer matthias-springer changed the title [mlir][vector] Remove type check from vector.from_elements [mlir][vector] Relax operand type restrictions for vector.from_elements Jun 24, 2025
@matthias-springer matthias-springer requested a review from grypp June 24, 2025 15:11
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

It would be good to add some negative tests and you also may want to update the title ;)

@dcaballe
Copy link
Contributor

LGTM. Could we also check that vector.broadcast supports the pointer case and add a test for it?
(I'm still wondering why we have the two ops...)

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_from_elements branch from 9e44fa3 to 6f2a062 Compare June 25, 2025 06:54
@matthias-springer
Copy link
Member Author

matthias-springer commented Jun 25, 2025

you also may want to update the title

The title of the PR is accurate, isn't it?

@zero9178
Copy link
Member

you also may want to update the title

The title of the PR is accurate, isn't it?

The PR description and commit title refer to vector.from_elements rather than, vector.splat

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with the typos adjusted

@matthias-springer matthias-springer changed the title [mlir][vector] Relax operand type restrictions for vector.from_elements [mlir][vector] Relax operand type restrictions for vector.splat Jun 25, 2025
@matthias-springer matthias-springer merged commit 4f9adb6 into main Jun 25, 2025
7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/vector_from_elements branch June 25, 2025 08:45
Comment on lines +1978 to +1983
// expected-note @+1 {{prior use here}}
func.func @vector_splat_type_mismatch(%a: f32) {
// expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
%0 = vector.splat %a : vector<1xi32>
return
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[ultra-nit] #145699 :)

anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…vm#145517)

The vector type allows element types that implement the
`VectorElementTypeInterface`. `vector.splat` should allow any element
type that is supported by the vector type.
rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
…vm#145517)

The vector type allows element types that implement the
`VectorElementTypeInterface`. `vector.splat` should allow any element
type that is supported by the vector type.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants