You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[mlir] [tosa] Bug fixes in shape inference pass (llvm#104146)
This change addresses 2 bugs in the TOSA shape inference pass
(`--tosa-infer-shapes`). The included unit test contains a detailed
description of the issues.
- Input IR
```
func.func @main(%arg0: tensor<1x2x8xf32>) {
%0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32>
%c0 = arith.constant 0 : index
%dim = tensor.dim %0, %c0 : tensor<?x2x8xf32>
%expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
%expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
return
}
```
- Output IR
```
module {
func.func @main(%arg0: tensor<1x2x8xf32>) {
%0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
// This cast was previously inserted between both 'tensor.expand_shape' ops.
%cast = tensor.cast %0 : tensor<1x2x8xf32> to tensor<?x2x8xf32>
%c0 = arith.constant 0 : index
%dim = tensor.dim %0, %c0 : tensor<1x2x8xf32>
// The operand of the first 'tensor.expand_shape' op was not previously updated
// from '%0' to '%cast' due to an invalidation of the iterator traversing the
// use list of the 'tosa.cast' op.
%expanded_0 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
%expanded_1 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
return
}
```
0 commit comments