You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Avoid generating spurious tensor.extract_slice, follow-on for llvm#114315.
This is best to demonstrate with an example. Here's input for
`GeneralizeOuterUnitDimsPackOpPattern`:
```mlir
%pack = tensor.pack %input
padding_value(%pad : f32)
inner_dims_pos = [1, 0]
inner_tiles = [2, %tile_dim_1]
into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```
Output _before_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
ins(%extracted_slice : tensor<?x2xf32>)
outs(%empty : tensor<2x?xf32>)
permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed=
into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
tensor<2x?xf32> into tensor<1x1x2x?xf32>
```
Output _after_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
ins(%padded : tensor<?x2xf32>)
outs(%empty : tensor<2x?xf32>) permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
tensor<2x?xf32> into tensor<1x1x2x?xf32>
```
This PR also adds a check to verify that only the last N (for some value
of N) trailing dims that are being tiled. From what I can tell, that's
always the case in practice. For this PR, it simplifies how the
permutation for linalg.transpose is computed. If needed, this can be
relaxed in the future
0 commit comments