-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Enhance verify checks for PAD Op #137177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
psunn
commented
Apr 24, 2025
- add padding shape verification
- add and update LIT test
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Peng Sun (psunn) Changes
Full diff: https://github.com/llvm/llvm-project/pull/137177.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c36c1074f5780..656a57971f634 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1515,15 +1515,49 @@ LogicalResult tosa::PadOp::verify() {
if (!inputType || !outputType)
return success();
- auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
+ auto inputRank = inputType.getRank();
+ auto outputRank = outputType.getRank();
+ if (inputRank != outputRank)
+ return emitOpError() << "expect same input and output tensor rank, but got "
+ << "inputRank: " << inputRank
+ << ", outputRank: " << outputRank;
+
+ DenseIntElementsAttr paddingAttr;
+ if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
+ return failure();
+
+ auto paddingValues = paddingAttr.getValues<APInt>();
+ if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
+ return emitOpError() << "padding tensor must have " << inputRank
+ << " * 2 = " << inputRank * 2 << " elements, but got "
+ << paddingValues.size();
+
+ auto inputShape = inputType.getShape();
+ auto outputShape = outputType.getShape();
+
+ for (int64_t i = 0; i < inputRank; ++i) {
+ // Skip shape verification for dynamic dims
+ if (inputShape[i] == ShapedType::kDynamic ||
+ outputShape[i] == ShapedType::kDynamic)
+ continue;
+
+ int64_t padStart = paddingValues[i * 2].getSExtValue();
+ int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
- if (inputType.getRank() != outputType.getRank())
- return emitOpError() << "expect same input and output tensor rank.";
+ if (padStart < 0 || padEnd < 0) {
+ return emitOpError() << "padding values must be non-negative, got ["
+ << padStart << ", " << padEnd << "] for dimension "
+ << i;
+ }
- if (paddingRank != inputType.getRank() * 2)
- return emitOpError() << "expected padding tensor dim 0 to have size "
- << inputType.getRank() * 2
- << " (2*rank(shape1)) but got size " << paddingRank;
+ if (outputShape[i] != inputShape[i] + padStart + padEnd) {
+ return emitOpError() << "mismatch in output shape at dimension " << i
+ << ": expected " << inputShape[i] << " + "
+ << padStart << " + " << padEnd << " = "
+ << (inputShape[i] + padStart + padEnd)
+ << ", but got " << outputShape[i];
+ }
+ }
return success();
}
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index 25e1aa195c3a0..8739f979d8d50 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -20,10 +20,10 @@ func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>)
// -----
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
- %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
- return %1 : tensor<13x21x3xi8>
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+ return %1 : tensor<13x22x4xi8>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..1e7abd0532090 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -303,7 +303,7 @@ func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2
// -----
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
+func.func @test_pad_padding_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
%0 = tosa.pad %arg0, %arg1, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
@@ -312,9 +312,18 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
// -----
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_pad_const_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
// expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+ return %1 : tensor<13x22x4xi8>
+}
+
+// -----
+
+func.func @test_pad_output_mismatch(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+ %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // expected-error@+1 {{mismatch in output shape at dimension 1: expected 21 + 0 + 1 = 22, but got 21}}
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}
@@ -324,7 +333,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) ->
func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
- // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
+ // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank}}
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
}
@@ -341,7 +350,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
- // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
+ // expected-error@+1 {{'tosa.pad' op padding tensor must have 2 * 2 = 4 elements, but got 6}}
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21xf32>
return
}
@@ -361,7 +370,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
- // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
+ // expected-error@+1 {{'tosa.pad' op padding tensor must have 3 * 2 = 6 elements, but got 4}}
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
return %1 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index c862ae375f33b..a7b4f2dc90e10 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -407,11 +407,11 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
// -----
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
// expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
- %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
- return %1 : tensor<13x21x3xi8>
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
+ return %1 : tensor<13x22x4xi8>
}
// -----
|
Thanks for the patch. LGTM. one small nit. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
* add padding shape verification * add checkErrorIfPad() * add and update LIT test Change-Id: Ie77ba21d271362906618389cf90cf0af20e2fcae Signed-off-by: Peng Sun <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @psunn!
* add padding shape verification * add and update LIT test Signed-off-by: Peng Sun <[email protected]>
* add padding shape verification * add and update LIT test Signed-off-by: Peng Sun <[email protected]>
* add padding shape verification * add and update LIT test Signed-off-by: Peng Sun <[email protected]>
* add padding shape verification * add and update LIT test Signed-off-by: Peng Sun <[email protected]>
* add padding shape verification * add and update LIT test Signed-off-by: Peng Sun <[email protected]>
* add padding shape verification * add and update LIT test (cherry picked from commit af32972) Change-Id: Ie77ba21d271362906618389cf90cf0af20e2fcae Signed-off-by: Peng Sun <[email protected]>