Skip to content

[mlir][tosa] Add missing verifier check for tosa.reshape #109301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 20, 2024

Conversation

CoTinker
Copy link
Contributor

This PR adds a missing verifier check for tosa.reshape, ensuring that the number of elements in new_shape matches the number of elements in the input tensor. Fixes #108151 and fixes #107969.

This PR adds a missing verifier check for `tosa.reshape`,
ensuring that the number of elements in `new_shape` matches
the number of elements in the input tensor.
@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

Changes

This PR adds a missing verifier check for tosa.reshape, ensuring that the number of elements in new_shape matches the number of elements in the input tensor. Fixes #108151 and fixes #107969.


Full diff: https://github.com/llvm/llvm-project/pull/109301.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+19-4)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+16)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..6dce3d03066c9a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -30,6 +30,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace mlir::tosa;
 
@@ -1015,12 +1017,25 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
                            << newShapeDim;
   }
 
-  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
+  if (inputType.hasStaticShape()) {
     int64_t inputElementsNum = inputType.getNumElements();
-    int64_t outputElementsNum = outputType.getNumElements();
-    if (inputElementsNum != outputElementsNum) {
+    if (outputType.hasStaticShape()) {
+      int64_t outputElementsNum = outputType.getNumElements();
+      if (inputElementsNum != outputElementsNum) {
+        return emitOpError() << "cannot reshape " << inputElementsNum
+                             << " elements into " << outputElementsNum;
+      }
+    }
+
+    int64_t newShapeElementsNum = std::accumulate(
+        getNewShape().begin(), getNewShape().end(), 1LL,
+        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
+    bool isStaticNewShape =
+        llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
+        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
       return emitOpError() << "cannot reshape " << inputElementsNum
-                           << " elements into " << outputElementsNum;
+                           << " elements into " << newShapeElementsNum;
     }
   }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 311fdb1226c523..e5c5b9b3663903 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -360,6 +360,22 @@ func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
 
 // -----
 
+func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot reshape 1 elements into 4}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 4>} : (tensor<1xf32>) -> tensor<?x4xf32>
+  return
+}
+
+// -----
+
+func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 4}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 4>} : (tensor<8xf32>) -> tensor<?x4xf32>
+  return
+}
+
+// -----
+
 func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
   // expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks!

@GeorgeARM GeorgeARM merged commit e6eb94d into llvm:main Sep 20, 2024
11 checks passed
@CoTinker CoTinker deleted the verify_newshape branch September 20, 2024 10:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment