Skip to content

Commit cbe0686

Browse files
matthias-springerrlavaee
authored andcommitted
[mlir][vector] Relax operand type restrictions for vector.splat (llvm#145517)
The vector type allows element types that implement the `VectorElementTypeInterface`. `vector.splat` should allow any element type that is supported by the vector type.
1 parent 4907a37 commit cbe0686

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2920,8 +2920,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
29202920
]> {
29212921
let summary = "vector splat or broadcast operation";
29222922
let description = [{
2923-
Broadcast the operand to all elements of the result vector. The operand is
2924-
required to be of integer/index/float type.
2923+
Broadcast the operand to all elements of the result vector. The type of the
2924+
operand must match the element type of the vector type.
29252925

29262926
Example:
29272927

@@ -2931,8 +2931,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
29312931
```
29322932
}];
29332933

2934-
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
2935-
"integer/index/float type">:$input);
2934+
let arguments = (ins AnyType:$input);
29362935
let results = (outs AnyVectorOfAnyRank:$aggregate);
29372936

29382937
let builders = [

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1975,6 +1975,15 @@ func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32>
19751975

19761976
// -----
19771977

1978+
// expected-note @+1 {{prior use here}}
1979+
func.func @vector_splat_type_mismatch(%a: f32) {
1980+
// expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
1981+
%0 = vector.splat %a : vector<1xi32>
1982+
return
1983+
}
1984+
1985+
// -----
1986+
19781987
//===----------------------------------------------------------------------===//
19791988
// vector.load
19801989
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
149149
}
150150

151151
// CHECK-LABEL: @vector_broadcast
152-
func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> {
152+
func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>, %f: vector<8x1x!llvm.ptr<1>>) {
153153
// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
154154
%0 = vector.broadcast %a : f32 to vector<f32>
155155
// CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<4xf32>
@@ -162,7 +162,9 @@ func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: ve
162162
%4 = vector.broadcast %d : vector<1x16xf32> to vector<8x16xf32>
163163
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
164164
%5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32>
165-
return %4 : vector<8x16xf32>
165+
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>>
166+
%6 = vector.broadcast %f : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>>
167+
return
166168
}
167169

168170
// CHECK-LABEL: @shuffle0D
@@ -959,13 +961,16 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
959961
}
960962

961963
// CHECK-LABEL: func @test_splat_op
962-
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
963-
func.func @test_splat_op(%s : f32) {
964-
// CHECK: vector.splat [[S]] : vector<8xf32>
964+
// CHECK-SAME: %[[s:.*]]: f32, %[[s2:.*]]: !llvm.ptr<1>
965+
func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
966+
// CHECK: vector.splat %[[s]] : vector<8xf32>
965967
%v = vector.splat %s : vector<8xf32>
966968

967-
// CHECK: vector.splat [[S]] : vector<4xf32>
969+
// CHECK: vector.splat %[[s]] : vector<4xf32>
968970
%u = "vector.splat"(%s) : (f32) -> vector<4xf32>
971+
972+
// CHECK: vector.splat %[[s2]] : vector<16x!llvm.ptr<1>>
973+
%w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
969974
return
970975
}
971976

0 commit comments

Comments
 (0)