-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[tosa] Fix crash in shape inference for tosa.transpose
#74367
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Fixes a crash in `TransposeOp::inferReturnTypeComponents()` when the supplied permutation tensor is rank-0. Also removes some dead code from the type inference function. Fix llvm#74237
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Felix Schneider (ubfx) ChangesFixes a crash in Fix #74237 Full diff: https://github.com/llvm/llvm-project/pull/74367.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f490cb1baa309..259fb6394669a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -983,6 +983,10 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
ShapeAdaptor inputShape(adaptor.getInput1().getType());
ShapeAdaptor permsShape(adaptor.getPerms().getType());
+ // We cannot infer anything from a rank-0 "permutation" tensor.
+ if (permsShape.hasRank() && permsShape.getRank() == 0)
+ return failure();
+
// If input rank and permutation length is unknown, the output rank is
// unknown.
if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -997,15 +1001,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
return failure();
}
- // Without the input dims we cannot determine the output dim sizes but we
- // can determine the output rank.
SmallVector<int64_t> outputShape;
- if (!inputShape.hasRank()) {
- outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic);
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
- return success();
- }
-
// Rank-0 means no permutations matter.
if (inputShape.getRank() == 0) {
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index f057431a841b5..c240f5334c149 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1310,3 +1310,14 @@ func.func @test_large_constant_permutation() {
return
}
+// -----
+
+// CHECK-LABEL: test_rank0_transpose_perms
+// Fail to infer the shape but not crash.
+func.func @test_rank0_transpose_perms() {
+ %14 = tensor.empty() : tensor<5x27xi64>
+ %cst = tensor.empty() : tensor<i32>
+ // CHECK: tosa.transpose
+ %72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
%14 = tensor.empty() : tensor<5x27xi64> | ||
%cst = tensor.empty() : tensor<i32> | ||
// CHECK: tosa.transpose | ||
%72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on the description (in the TOSA Spec) of tosa.transpose
it seems like this op should be illegal. The permutation tensor is expected to be a rank-1 tensor with an element count equal to the rank of the input tensor.
Not crashing seems fine, but the verifier for tosa.transpose
should probably reject this op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, at the moment, transpose
doesn't have a verifier so this was just a crash-fix PR but I will add a verifier which checks for those thing in a separate PR.
This patch adds a verifier to `tosa.transpose` which fixes a crash. Related: llvm#74367 Fix llvm#74479
This patch adds a verifier to `tosa.transpose` which fixes a crash. Related: llvm/llvm-project#74367 Fix llvm/llvm-project#74479
Fixes a crash in
TransposeOp::inferReturnTypeComponents()
when the supplied permutation tensor is rank-0.Also removes some dead code from the type inference function.
Fix #74237