-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Add verifier check for Concat Op #136047
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis adds verifier check for Concat Op add tests in verifier.mlir Full diff: https://github.com/llvm/llvm-project/pull/136047.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..d9e77dd3f3770 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1178,6 +1178,25 @@ LogicalResult tosa::ConcatOp::verify() {
<< " on operands 0 and " << operandNum;
}
}
+
+ // ERROR_IF(axis_sum != shape[axis]);
+ int64_t axis_sum = 0;
+ for (const auto &input : inputList) {
+ const ShapeAdaptor inputShape(input.getType());
+ if (inputShape.isDynamicDim(axis)) {
+ // make axis_sum negative to indicate invalid value
+ axis_sum = -1;
+ break;
+ }
+ axis_sum += inputShape.getDimSize(axis);
+ }
+ const ShapeAdaptor outputShape(outType);
+ if (axis_sum >= 0 && outputShape.hasRank() &&
+ !outputShape.isDynamicDim(axis) &&
+ axis_sum != outputShape.getDimSize(axis))
+ return emitOpError("requires sum of axis dimensions of input1 "
+ "equal to output axis dimension, got ")
+ << axis_sum << " and " << outputShape.getDimSize(axis);
}
return success();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index fc98aa95ed5b3..1ff73bee3923d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
-func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
- %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
- return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-func.func @test_concat_zero_inputs() {
- // expected-error@+1 {{'tosa.concat' op expect at least one input}}
- %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
- %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
- %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index efdd26a9346fb..e6310fee22479 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -167,3 +167,42 @@ func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
%2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
return %2 : tensor<f32>
}
+
+// -----
+
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+ // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+ return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+func.func @test_concat_zero_inputs() {
+ // expected-error@+1 {{'tosa.concat' op expect at least one input}}
+ %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+ %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
+ // expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+ return %0 : tensor<2x?xf32>
+}
|
@llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesThis adds verifier check for Concat Op add tests in verifier.mlir Full diff: https://github.com/llvm/llvm-project/pull/136047.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..d9e77dd3f3770 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1178,6 +1178,25 @@ LogicalResult tosa::ConcatOp::verify() {
<< " on operands 0 and " << operandNum;
}
}
+
+ // ERROR_IF(axis_sum != shape[axis]);
+ int64_t axis_sum = 0;
+ for (const auto &input : inputList) {
+ const ShapeAdaptor inputShape(input.getType());
+ if (inputShape.isDynamicDim(axis)) {
+ // make axis_sum negative to indicate invalid value
+ axis_sum = -1;
+ break;
+ }
+ axis_sum += inputShape.getDimSize(axis);
+ }
+ const ShapeAdaptor outputShape(outType);
+ if (axis_sum >= 0 && outputShape.hasRank() &&
+ !outputShape.isDynamicDim(axis) &&
+ axis_sum != outputShape.getDimSize(axis))
+ return emitOpError("requires sum of axis dimensions of input1 "
+ "equal to output axis dimension, got ")
+ << axis_sum << " and " << outputShape.getDimSize(axis);
}
return success();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index fc98aa95ed5b3..1ff73bee3923d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
-func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
- %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
- return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-func.func @test_concat_zero_inputs() {
- // expected-error@+1 {{'tosa.concat' op expect at least one input}}
- %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
- %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
- %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index efdd26a9346fb..e6310fee22479 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -167,3 +167,42 @@ func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
%2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
return %2 : tensor<f32>
}
+
+// -----
+
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+ // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+ return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+func.func @test_concat_zero_inputs() {
+ // expected-error@+1 {{'tosa.concat' op expect at least one input}}
+ %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+ %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
+ // expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+ return %0 : tensor<2x?xf32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the PR!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had a nit, otherwise LGTM
Apologies, this needs an update due to conflicts |
This adds verifier check for Concat Op to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension add tests in verifier.mlir also moved existing concat verifier checks to verifier.mlir Signed-off-by: Tai Ly <[email protected]> Change-Id: I53e41ca3c1f4ee48997c510fee2c16ed912dfaa0
rebased and resolved conflict |
This adds verifier check for Concat Op to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension add tests in verifier.mlir also moved existing concat verifier checks to verifier.mlir Signed-off-by: Tai Ly <[email protected]>
This adds verifier check for Concat Op to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension add tests in verifier.mlir also moved existing concat verifier checks to verifier.mlir Signed-off-by: Tai Ly <[email protected]>
This adds verifier check for Concat Op to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension add tests in verifier.mlir also moved existing concat verifier checks to verifier.mlir Signed-off-by: Tai Ly <[email protected]>
This adds verifier check for Concat Op to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension add tests in verifier.mlir also moved existing concat verifier checks to verifier.mlir Signed-off-by: Tai Ly <[email protected]>
This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension
add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir