Skip to content

Commit 9e44fa3

Browse files
[mlir][vector] Allow pointer types for vector.from_elements
1 parent 63f30d7 commit 9e44fa3

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2920,8 +2920,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
29202920
]> {
29212921
let summary = "vector splat or broadcast operation";
29222922
let description = [{
2923-
Broadcast the operand to all elements of the result vector. The operand is
2924-
required to be of integer/index/float type.
2923+
Broadcast the operand to all elements of the result vector. The type of the
2924+
operand must match the element type of the vector type.
29252925

29262926
Example:
29272927

@@ -2931,8 +2931,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
29312931
```
29322932
}];
29332933

2934-
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
2935-
"integer/index/float type">:$input);
2934+
let arguments = (ins AnyType:$input);
29362935
let results = (outs AnyVectorOfAnyRank:$aggregate);
29372936

29382937
let builders = [

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -959,13 +959,16 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
959959
}
960960

961961
// CHECK-LABEL: func @test_splat_op
962-
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
963-
func.func @test_splat_op(%s : f32) {
962+
// CHECK-SAME: [[S:%arg[0-9]+]]: f32, [[S2:%arg[0-9]+]]: !llvm.ptr<1>
963+
func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
964964
// CHECK: vector.splat [[S]] : vector<8xf32>
965965
%v = vector.splat %s : vector<8xf32>
966966

967967
// CHECK: vector.splat [[S]] : vector<4xf32>
968968
%u = "vector.splat"(%s) : (f32) -> vector<4xf32>
969+
970+
// CHECK: vector.splat [[S2]] : vector<16x!llvm.ptr<1>>
971+
%w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
969972
return
970973
}
971974

0 commit comments

Comments
 (0)