You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This patch adds support for `tosa.scatter` lowering in the `--tosa-to-scf` pass. Here's an example for this lowering:
```
func.func @tosa(
%valuesIn : tensor<3x7x5xi32>,
%indices : tensor<3x6xi32>,
%input : tensor<3x6x5xi32>) ->
tensor<3x7x5xi32> {
%0 = "tosa.scatter"(%valuesIn, %indices, %input) :
(tensor<3x7x5xi32>,
tensor<3x6xi32>,
tensor<3x6x5xi32>) ->
(tensor<3x7x5xi32>)
return %0 : tensor<3x7x5xi32>
}
```
translates to
func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%c6 = arith.constant 6 : index
%c2 = arith.constant 2 : index
%c5 = arith.constant 5 : index
%c0_0 = arith.constant 0 : index
%c1_1 = arith.constant 1 : index
%0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) {
%1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) {
%extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32>
%2 = arith.index_cast %extracted : i32 to index
%extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
%inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
scf.yield %inserted_slice : tensor<3x7x5xi32>
}
scf.yield %1 : tensor<3x7x5xi32>
}
return %0 : tensor<3x7x5xi32>
}
```
We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons:
- The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon).
- The `tosa.scatter` and `tensor.scatter` op have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt to the needs of `tensor.scatter`. While this overhead may be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through the `tensor` dialect that we're missing, this additional complexity may not be justified.
Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D151117
0 commit comments