Commit ea47887
Fix for TOSA-to-linalg lowering of tosa.transpose op (#72698)
The TOSA-to-linalg conversion of `tosa.transpose` contains a bug in the computation of the result tensor shape when using dynamic dimensions. This bug may have widespread implications in projects such as TensorFlow, where `tosa.transpose` is frequently generated.

Consider the following TOSA code using only static dimensions. The code transposes a tensor of shape 10x11x12 into 12x10x11 by mapping input dimensions [2, 0, 1] to output dimensions [0, 1, 2].

```
func.func @test_tosa_transpose(%input: tensor<10x11x12xf32>) -> tensor<12x10x11xf32> {
  %perms = "tosa.const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<10x11x12xf32>, tensor<3xi32>) -> tensor<12x10x11xf32>
  return %transposed : tensor<12x10x11xf32>
}
```

The code is correctly lowered to:

```
#map = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func.func @test_tosa_transpose(%arg0: tensor<10x11x12xf32>) -> tensor<12x10x11xf32> {
    %empty = tensor.empty() : tensor<12x10x11xf32>
    %transposed = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<10x11x12xf32>) outs(%empty : tensor<12x10x11xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<12x10x11xf32>
    return %transposed : tensor<12x10x11xf32>
  }
}
```

Now let's make all dimensions dynamic in the TOSA code:

```
func.func @test_tosa_transpose(%input: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %perms = "tosa.const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
  return %transposed : tensor<?x?x?xf32>
}
```

The `tensor.empty()` op now needs additional information about the size of the output tensor, which is computed dynamically with a set of `tensor.dim` ops. The comments below assume an input tensor of size 10x11x12, as before. The code is lowered as:

```
#map = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func.func @test_tosa_transpose(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %arg0_dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>  // Evaluates to 10
    %arg0_dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>  // Evaluates to 11
    %arg0_dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>  // Evaluates to 12
    %empty = tensor.empty(%arg0_dim1, %arg0_dim2, %arg0_dim0) : tensor<?x?x?xf32>  // Output of type tensor<11x12x10> WRONG!
    %transposed = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf32>) outs(%empty : tensor<?x?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<?x?x?xf32>
    return %transposed : tensor<?x?x?xf32>
  }
}
```

The output tensor shape is dynamically computed as 11x12x10 instead of 12x10x11. Since the total size of the output tensor is still the same, the code does not segfault after bufferization. However, index computations are invalid and lead to silent wrong answers (SWAs).
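To see the faulty index arithmetic in isolation, here is a minimal standalone C++ sketch (illustrative only, not part of the commit) that contrasts the output shape computed by the old lowering with the correct one for the example above:

```cpp
#include <array>
#include <cstdio>

int main() {
  // Input shape 10x11x12 and permutation [2, 0, 1], as in the example.
  std::array<int, 3> inShape = {10, 11, 12};
  std::array<int, 3> perms = {2, 0, 1};
  std::array<int, 3> buggy{}, fixed{};
  for (int i = 0; i < 3; ++i) {
    buggy[perms[i]] = inShape[i]; // old lowering: applies the inverse permutation
    fixed[i] = inShape[perms[i]]; // fixed lowering: outShape[i] = inShape[perms[i]]
  }
  std::printf("buggy: %dx%dx%d\n", buggy[0], buggy[1], buggy[2]); // 11x12x10
  std::printf("fixed: %dx%dx%d\n", fixed[0], fixed[1], fixed[2]); // 12x10x11
}
```

Because the old mapping is the inverse permutation of the correct one, the element count is preserved, which is exactly why the bug bufferizes cleanly and surfaces only as wrong index computations.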
1 parent b810b66 commit ea47887

File tree

2 files changed: +31 -9 lines changed


mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 2 additions & 3 deletions
```diff
@@ -1072,12 +1072,11 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {

     SmallVector<AffineExpr, 2> inputExprs;
     inputExprs.resize(resultTy.getRank());
-    auto operandTy = cast<ShapedType>(input.getType());
     for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
       auto index = permutation.index();
       auto value = permutation.value().getZExtValue();
-      if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
-        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
+      if (!resultTy.hasRank() || resultTy.isDynamicDim(index)) {
+        dynDims[index] = rewriter.create<tensor::DimOp>(loc, input, value);
       }
       inputExprs[value] = rewriter.getAffineDimExpr(index);
     }
```
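In this pattern, `index` enumerates output dimensions and `value = perms[index]` is the input dimension that feeds output dimension `index`. The fix therefore keys the dynamic-size slot `dynDims` by `index` and queries `tensor.dim` on input dimension `value`; it likewise guards on `resultTy` rather than the operand type, since it is the result's dynamic dimensions that are being materialized for `tensor.empty`.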

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 29 additions & 6 deletions
```diff
@@ -877,14 +877,14 @@ func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>

-// CHECK-LABEL: @test_transpose_dyn
+// CHECK-LABEL: @test_transpose_dyn_multiple_2d
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
-func.func @test_transpose_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
+func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
   %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-  // CHECK: %[[C1:.+]] = arith.constant 1
-  // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0
+  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1
+  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs([[OUT:%.+]] : tensor<?x?xf32>)
   // CHECK: ^bb0([[ARG1:%.+]]: f32, [[ARG2:%.+]]: f32)
@@ -896,6 +896,29 @@ func.func @test_transpose_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {

 // -----

+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_transpose_dyn_multiple_3d
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
+func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
+  %0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+  // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
+  // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) {
+  // CHECK: ^bb0(%[[IN0:.*]]: f32, %[[OUT0:.*]]: f32):
+  // CHECK: linalg.yield %[[IN0]] : f32
+  // CHECK: } -> tensor<?x?x?xf32>
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+  return
+}
+
+// -----

 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
```
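The switch from `CHECK` to `CHECK-DAG` for the `arith.constant` and `tensor.dim` lines lets FileCheck match those ops in any relative order, keeping the tests stable if the lowering or later canonicalization reorders them.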
