Skip to content

Commit 11976c2

Browse files
zero9178rlavaee
authored andcommitted
[mlir][tensor] Relax input type requirement on tensor.splat (llvm#145893)
`tensor.splat` is currently restricted to only accepting input values that are of integer, index or float type. This is much more restrictive than the tensor type itself as well as any lowerings of it. This PR therefore removes this restriction by using `AnyType` for the input value. Whether the type is actually valid or not for a tensor remains verified through the type equality of the result tensor element type and the input type.
1 parent 528f1ae commit 11976c2

File tree

4 files changed

+29
-11
lines changed

4 files changed

+29
-11
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,8 +1771,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
17711771
]> {
17721772
let summary = "tensor splat or broadcast operation";
17731773
let description = [{
1774-
Broadcast the operand to all elements of the result tensor. The operand is
1775-
required to be of integer/index/float type.
1774+
Broadcast the operand to all elements of the result tensor.
17761775

17771776
An additional argument of type `index` must be provided for each dynamic
17781777
dimension present in the result type.
@@ -1795,8 +1794,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
17951794
```
17961795
}];
17971796

1798-
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
1799-
"integer/index/float type">:$input,
1797+
let arguments = (ins AnyType:$input,
18001798
Variadic<Index>:$dynamicSizes);
18011799
let results = (outs AnyRankedTensor:$aggregate);
18021800

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
615615

616616
// -----
617617

618+
// CHECK-LABEL: func @tensor.splat_other(
619+
// CHECK-SAME: %[[F:.*]]: !test.memref_element)
620+
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!test.memref_element>
621+
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
622+
// CHECK: %[[MAPPED:.*]] = linalg.map
623+
// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!test.memref_element>)
624+
// CHECK: linalg.yield %[[F]]
625+
// CHECK: return %[[MAPPED]] : tensor<10x2x4x!test.memref_element>
626+
func.func @tensor.splat_other(%f: !test.memref_element) -> tensor<10x2x4x!test.memref_element> {
627+
%t = tensor.splat %f : tensor<10x2x4x!test.memref_element>
628+
return %t : tensor<10x2x4x!test.memref_element>
629+
}
630+
631+
// -----
632+
618633
// CHECK-LABEL: func @tensor.concat(
619634
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
620635
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) {
466466

467467
// -----
468468

469-
func.func @invalid_splat(%v : vector<8xf32>) {
470-
// expected-error@+1 {{must be integer/index/float type}}
471-
%w = tensor.splat %v : tensor<8xvector<8xf32>>
469+
// expected-note@+1 {{prior use here}}
470+
func.func @invalid_splat(%v : f32) {
471+
// expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
472+
%w = tensor.splat %v : tensor<1xi32>
472473
return
473474
}
474475

mlir/test/Dialect/Tensor/ops.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,13 +313,17 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
313313
// -----
314314

315315
// CHECK-LABEL: func @test_splat_op
316-
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
317-
func.func @test_splat_op(%s : f32) {
318-
// CHECK: tensor.splat [[S]] : tensor<8xf32>
316+
// CHECK-SAME: %[[S:.*]]: f32
317+
// CHECK-SAME: %[[P:.*]]: !llvm.ptr
318+
func.func @test_splat_op(%s : f32, %p : !llvm.ptr) {
319+
// CHECK: tensor.splat %[[S]] : tensor<8xf32>
319320
%v = tensor.splat %s : tensor<8xf32>
320321

321-
// CHECK: tensor.splat [[S]] : tensor<4xf32>
322+
// CHECK: tensor.splat %[[S]] : tensor<4xf32>
322323
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
324+
325+
// CHECK: tensor.splat %[[P]] : tensor<8x!llvm.ptr>
326+
%w = tensor.splat %p : tensor<8x!llvm.ptr>
323327
return
324328
}
325329

0 commit comments

Comments
 (0)