Skip to content

[mlir][tosa] Enhance verify checks for PAD Op #137177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 28, 2025

Conversation

psunn
Copy link
Contributor

@psunn psunn commented Apr 24, 2025

  • add padding shape verification
  • add and update LIT test

@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Peng Sun (psunn)

Changes
  • add padding shape verification
  • add and update LIT test

Full diff: https://github.com/llvm/llvm-project/pull/137177.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+41-7)
  • (modified) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+14-5)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c36c1074f5780..656a57971f634 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1515,15 +1515,49 @@ LogicalResult tosa::PadOp::verify() {
   if (!inputType || !outputType)
     return success();
 
-  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
+  auto inputRank = inputType.getRank();
+  auto outputRank = outputType.getRank();
+  if (inputRank != outputRank)
+    return emitOpError() << "expect same input and output tensor rank, but got "
+                         << "inputRank: " << inputRank
+                         << ", outputRank: " << outputRank;
+
+  DenseIntElementsAttr paddingAttr;
+  if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
+    return failure();
+
+  auto paddingValues = paddingAttr.getValues<APInt>();
+  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
+    return emitOpError() << "padding tensor must have " << inputRank
+                         << " * 2 = " << inputRank * 2 << " elements, but got "
+                         << paddingValues.size();
+
+  auto inputShape = inputType.getShape();
+  auto outputShape = outputType.getShape();
+
+  for (int64_t i = 0; i < inputRank; ++i) {
+    // Skip shape verification for dynamic dims
+    if (inputShape[i] == ShapedType::kDynamic ||
+        outputShape[i] == ShapedType::kDynamic)
+      continue;
+
+    int64_t padStart = paddingValues[i * 2].getSExtValue();
+    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
 
-  if (inputType.getRank() != outputType.getRank())
-    return emitOpError() << "expect same input and output tensor rank.";
+    if (padStart < 0 || padEnd < 0) {
+      return emitOpError() << "padding values must be non-negative, got ["
+                           << padStart << ", " << padEnd << "] for dimension "
+                           << i;
+    }
 
-  if (paddingRank != inputType.getRank() * 2)
-    return emitOpError() << "expected padding tensor dim 0 to have size "
-                         << inputType.getRank() * 2
-                         << " (2*rank(shape1)) but got size " << paddingRank;
+    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
+      return emitOpError() << "mismatch in output shape at dimension " << i
+                           << ": expected " << inputShape[i] << " + "
+                           << padStart << " + " << padEnd << " = "
+                           << (inputShape[i] + padStart + padEnd)
+                           << ", but got " << outputShape[i];
+    }
+  }
 
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index 25e1aa195c3a0..8739f979d8d50 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -20,10 +20,10 @@ func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>)
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %1 : tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+  return %1 : tensor<13x22x4xi8>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..1e7abd0532090 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -303,7 +303,7 @@ func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
+func.func @test_pad_padding_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
   %0 = tosa.pad %arg0, %arg1, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
@@ -312,9 +312,18 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_pad_const_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+  return %1 : tensor<13x22x4xi8>
+}
+
+// -----
+
+func.func @test_pad_output_mismatch(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{mismatch in output shape at dimension 1: expected 21 + 0 + 1 = 22, but got 21}}
   %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %1 : tensor<13x21x3xi8>
 }
@@ -324,7 +333,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) ->
 func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
   %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
+  // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank}}
   %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
 }
 
@@ -341,7 +350,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
 func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
   %0 = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
+  // expected-error@+1 {{'tosa.pad' op padding tensor must have 2 * 2 = 4 elements, but got 6}}
   %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21xf32>
   return
 }
@@ -361,7 +370,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
 func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
+  // expected-error@+1 {{'tosa.pad' op padding tensor must have 3 * 2 = 6 elements, but got 4}}
   %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
   return %1 : tensor<13x21x3xf32>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index c862ae375f33b..a7b4f2dc90e10 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -407,11 +407,11 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %1 : tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+  return %1 : tensor<13x22x4xi8>
 }
 
 // -----

@wonjeon
Copy link
Contributor

wonjeon commented Apr 24, 2025

Thanks for the patch. LGTM. one small nit.

@psunn psunn force-pushed the psunn/pad_verify branch from 83fb0f9 to 572b6e8 Compare April 24, 2025 17:55
@psunn psunn requested review from wonjeon and lhutton1 April 24, 2025 18:20
@psunn psunn force-pushed the psunn/pad_verify branch from 572b6e8 to 5d47838 Compare April 25, 2025 15:50
Copy link

github-actions bot commented Apr 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

 * add padding shape verification
 * add checkErrorIfPad()
 * add and update LIT test

Change-Id: Ie77ba21d271362906618389cf90cf0af20e2fcae
Signed-off-by: Peng Sun <[email protected]>
@psunn psunn force-pushed the psunn/pad_verify branch from 5d47838 to cb29e03 Compare April 25, 2025 15:56
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @psunn!

@lhutton1 lhutton1 merged commit af32972 into llvm:main Apr 28, 2025
11 checks passed
jyli0116 pushed a commit to jyli0116/llvm-project that referenced this pull request Apr 28, 2025
* add padding shape verification
 * add and update LIT test

Signed-off-by: Peng Sun <[email protected]>
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
* add padding shape verification
 * add and update LIT test

Signed-off-by: Peng Sun <[email protected]>
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
* add padding shape verification
 * add and update LIT test

Signed-off-by: Peng Sun <[email protected]>
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
* add padding shape verification
 * add and update LIT test

Signed-off-by: Peng Sun <[email protected]>
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
* add padding shape verification
 * add and update LIT test

Signed-off-by: Peng Sun <[email protected]>
lhutton1 pushed a commit to lhutton1/llvm-project that referenced this pull request Jun 11, 2025
 * add padding shape verification
 * add and update LIT test

(cherry picked from commit af32972)

Change-Id: Ie77ba21d271362906618389cf90cf0af20e2fcae
Signed-off-by: Peng Sun <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants