Skip to content

Commit 7563eb6

Browse files
authored
[tosa] Fix crash in shape inference for tosa.transpose (#74367)
Fixes a crash in `TransposeOp::inferReturnTypeComponents()` when the supplied permutation tensor is rank-0. Also removes some dead code from the type inference function. Fix #74237
1 parent ddebce7 commit 7563eb6

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,10 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
983983
ShapeAdaptor inputShape(adaptor.getInput1().getType());
984984
ShapeAdaptor permsShape(adaptor.getPerms().getType());
985985

986+
// We cannot infer anything from a rank-0 "permutation" tensor.
987+
if (permsShape.hasRank() && permsShape.getRank() == 0)
988+
return failure();
989+
986990
// If input rank and permutation length is unknown, the output rank is
987991
// unknown.
988992
if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -997,15 +1001,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
9971001
return failure();
9981002
}
9991003

1000-
// Without the input dims we cannot determine the output dim sizes but we
1001-
// can determine the output rank.
10021004
SmallVector<int64_t> outputShape;
1003-
if (!inputShape.hasRank()) {
1004-
outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic);
1005-
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1006-
return success();
1007-
}
1008-
10091005
// Rank-0 means no permutations matter.
10101006
if (inputShape.getRank() == 0) {
10111007
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,3 +1310,14 @@ func.func @test_large_constant_permutation() {
13101310
return
13111311
}
13121312

1313+
// -----
1314+
1315+
// CHECK-LABEL: test_rank0_transpose_perms
1316+
// Fail to infer the shape but not crash.
1317+
func.func @test_rank0_transpose_perms() {
1318+
%14 = tensor.empty() : tensor<5x27xi64>
1319+
%cst = tensor.empty() : tensor<i32>
1320+
// CHECK: tosa.transpose
1321+
%72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
1322+
return
1323+
}

0 commit comments

Comments
 (0)