Skip to content

[mlir][tosa] Add verifier check for Concat Op #136047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2025
Merged

Conversation

Tai78641
Copy link
Contributor

This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension

add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir

@llvmbot
Copy link
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension

add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir


Full diff: https://github.com/llvm/llvm-project/pull/136047.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+19)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (-31)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+39)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..d9e77dd3f3770 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1178,6 +1178,25 @@ LogicalResult tosa::ConcatOp::verify() {
                  << " on operands 0 and " << operandNum;
       }
     }
+
+    // ERROR_IF(axis_sum != shape[axis]);
+    int64_t axis_sum = 0;
+    for (const auto &input : inputList) {
+      const ShapeAdaptor inputShape(input.getType());
+      if (inputShape.isDynamicDim(axis)) {
+        // make axis_sum negative to indicate invalid value
+        axis_sum = -1;
+        break;
+      }
+      axis_sum += inputShape.getDimSize(axis);
+    }
+    const ShapeAdaptor outputShape(outType);
+    if (axis_sum >= 0 && outputShape.hasRank() &&
+        !outputShape.isDynamicDim(axis) &&
+        axis_sum != outputShape.getDimSize(axis))
+      return emitOpError("requires sum of axis dimensions of input1 "
+                         "equal to output axis dimension, got ")
+             << axis_sum << " and " << outputShape.getDimSize(axis);
   }
 
   return success();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index fc98aa95ed5b3..1ff73bee3923d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 
 // -----
 
-func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
-  // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
-  return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-func.func @test_concat_zero_inputs() {
-  // expected-error@+1 {{'tosa.concat' op expect at least one input}}
-  %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
-  %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
-  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %0 : tensor<2x2xf32>
-}
-
-// -----
-
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index efdd26a9346fb..e6310fee22479 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -167,3 +167,42 @@ func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
   %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
   return %2 : tensor<f32>
 }
+
+// -----
+
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+  // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+  return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+func.func @test_concat_zero_inputs() {
+  // expected-error@+1 {{'tosa.concat' op expect at least one input}}
+  %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+  %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  // expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}

@llvmbot
Copy link
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension

add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir


Full diff: https://github.com/llvm/llvm-project/pull/136047.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+19)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (-31)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+39)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..d9e77dd3f3770 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1178,6 +1178,25 @@ LogicalResult tosa::ConcatOp::verify() {
                  << " on operands 0 and " << operandNum;
       }
     }
+
+    // ERROR_IF(axis_sum != shape[axis]);
+    int64_t axis_sum = 0;
+    for (const auto &input : inputList) {
+      const ShapeAdaptor inputShape(input.getType());
+      if (inputShape.isDynamicDim(axis)) {
+        // make axis_sum negative to indicate invalid value
+        axis_sum = -1;
+        break;
+      }
+      axis_sum += inputShape.getDimSize(axis);
+    }
+    const ShapeAdaptor outputShape(outType);
+    if (axis_sum >= 0 && outputShape.hasRank() &&
+        !outputShape.isDynamicDim(axis) &&
+        axis_sum != outputShape.getDimSize(axis))
+      return emitOpError("requires sum of axis dimensions of input1 "
+                         "equal to output axis dimension, got ")
+             << axis_sum << " and " << outputShape.getDimSize(axis);
   }
 
   return success();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index fc98aa95ed5b3..1ff73bee3923d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 
 // -----
 
-func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
-  // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
-  return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-func.func @test_concat_zero_inputs() {
-  // expected-error@+1 {{'tosa.concat' op expect at least one input}}
-  %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
-  %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
-  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %0 : tensor<2x2xf32>
-}
-
-// -----
-
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index efdd26a9346fb..e6310fee22479 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -167,3 +167,42 @@ func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
   %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
   return %2 : tensor<f32>
 }
+
+// -----
+
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+  // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+  return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+func.func @test_concat_zero_inputs() {
+  // expected-error@+1 {{'tosa.concat' op expect at least one input}}
+  %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+  %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  // expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}

Copy link
Member

@Jerry-Ge Jerry-Ge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the PR!

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a nit, otherwise LGTM

@lhutton1
Copy link
Contributor

Apologies, this needs an update due to conflicts

This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions
is equal to the output's axis dimension

add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I53e41ca3c1f4ee48997c510fee2c16ed912dfaa0
@Tai78641
Copy link
Contributor Author

rebased and resolved conflict

@lhutton1 lhutton1 merged commit e98a61d into llvm:main Apr 24, 2025
11 checks passed
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the
output's axis dimension

add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir

Signed-off-by: Tai Ly <[email protected]>
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the
output's axis dimension

add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir

Signed-off-by: Tai Ly <[email protected]>
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the
output's axis dimension

add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir

Signed-off-by: Tai Ly <[email protected]>
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the
output's axis dimension

add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir

Signed-off-by: Tai Ly <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants