-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Add verifier for tosa.tile, fix shape inference crash #70972
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch adds an verifier to `tosa.tile` which checks input/output ranks and the length of the `multiples` array. The patch also fixes a crash in the shape inference when an invalid `multiples` array is supplied. Fix llvm#70415
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
✅ With the latest revision this PR passed the C/C++ code formatter. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Felix Schneider (ubfx) ChangesThis patch adds an verifier to Fix #70415 Full diff: https://github.com/llvm/llvm-project/pull/70972.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 81b9e93c2095f57..0a2f3271c37d212 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1644,6 +1644,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
);
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4ec6714a7e02a8b..375a7bbe38e8ec6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -863,7 +863,8 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
outputShape.resize(multiples.size(), ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
- }
+ } else if (inputShape.getRank() != multiples.size())
+ return failure();
// Any non dynamic dimension can be multiplied to a known size.
outputShape.reserve(multiples.size());
@@ -878,6 +879,24 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::TileOp::verify() {
+ ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
+ ShapedType outputType = llvm::cast<ShapedType>(getType());
+ auto multiples = getMultiples();
+
+ if (inputType.hasRank()) {
+ if (inputType.getRank() != multiples.size())
+ return emitOpError("expect 'multiples' array to have length ")
+ << inputType.getRank() << " but got " << multiples.size() << ".";
+ if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
+ return emitOpError("expect same input and output tensor rank.");
+ } else if (outputType.hasRank() && outputType.getRank() != multiples.size())
+ return emitOpError("expect 'multiples' array to have length ")
+ << outputType.getRank() << " but got " << multiples.size() << ".";
+
+ return success();
+}
+
bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != r.size() || l.size() != 1)
return false;
@@ -1830,9 +1849,8 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 102c9ed1578cde9..fd51d287bca0580 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -603,3 +603,13 @@ func.func nested @fold_reduce_rank_zero() {
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
return
}
+
+// -----
+
+// CHECK-LABEL: @fold_tile_rank_zero
+func.func nested @fold_tile_rank_zero() -> tensor<i32> {
+ // CHECK-NOT: tosa.tile
+ %0 = tensor.empty() : tensor<i32>
+ %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8e23a1fde04bc82..4a517cdec1fd7bc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -329,3 +329,12 @@ func.func @test_slice_invalid_size() {
%1 = tosa.slice %0 {size = array<i64: 1>, start = array<i64: 1, 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
return
}
+
+// -----
+
+func.func @test_tile_invalid_multiples() {
+ %0 = tensor.empty() : tensor<4x31x31xf32>
+ // expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
+ %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This patch adds an verifier to
tosa.tile
which checks input/output ranks and the length of themultiples
array. The patch also fixes a crash in the shape inference when an invalidmultiples
array is supplied.Fix #70415