Skip to content

[mlir][tosa] Fix validation pass assert #134445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 7, 2025
Merged

Conversation

Tai78641
Copy link
Contributor

@Tai78641 Tai78641 commented Apr 4, 2025

This fixes a validation pass assert when processing ops with quantized element types.
The failure case is added to invalid.mlir
The fix is to re-order the validation checking so that only ops with int/float operands and results pass the first stage of validation pass, so that the remaining checks do not need to handle quantized data types.

This fixes a validation pass assert when processing ops with
quantized element types. The fix is to re-order the validation
checking so that only ops with int/float operands and results
pass the first stage of validation pass, so that the remaining
checks do not need to handle quantized data types.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: Ibd71c0de019f2395721035ed51b4983b67d56d61
@llvmbot
Copy link
Member

llvmbot commented Apr 4, 2025

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This fixes a validation pass assert when processing ops with quantized element types.
The failure case is added to invalid.mlir
The fix is to re-order the validation checking so that only ops with int/float operands and results pass the first stage of validation pass, so that the remaining checks do not need to handle quantized data types.


Full diff: https://github.com/llvm/llvm-project/pull/134445.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+10-9)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+9)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3ec7354562d23..28e562c813eb3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1018,15 +1018,8 @@ void TosaValidation::runOnOperation() {
     if (op->getDialect() != tosaDialect)
       return;
 
-    // Profile-Extension based validation should be performed at the beginning.
-    if (strictOpSpecAlignment &&
-        failed(profileComp.checkProfile(op, targetEnv)))
-      return signalPassFailure();
-
-    if (strictOpSpecAlignment &&
-        failed(profileComp.checkExtension(op, targetEnv)))
-      return signalPassFailure();
-
+    // perform valid element type check at the beginning to
+    // protect rest of code against quantized element types
     for (Value operand : op->getOperands()) {
       auto elementTy = getElementTypeOrSelf(operand);
       if (!isValidElementType(elementTy)) {
@@ -1044,6 +1037,14 @@ void TosaValidation::runOnOperation() {
       }
     }
 
+    if (strictOpSpecAlignment &&
+        failed(profileComp.checkProfile(op, targetEnv)))
+      return signalPassFailure();
+
+    if (strictOpSpecAlignment &&
+        failed(profileComp.checkExtension(op, targetEnv)))
+      return signalPassFailure();
+
     if (!allowInvalidOpDatatypeCombinations &&
         failed(profileComp.checkInvalid(op))) {
       op->emitOpError("illegal: operand/result data types not supported");
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8cf6d4b154792..12b2379a592c3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -253,6 +253,15 @@ func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>
   return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
 }
 
+// -----
+// CHECK-LABEL: conv2d_quant_any
+func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i32<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>> {
+  %zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv2d' op is not profile-aligned: element type '!quant.any<i8<-8:7>>'}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i32<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+  return %0 : tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+}
+
 // -----
 
 func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {

@llvmbot
Copy link
Member

llvmbot commented Apr 4, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

This fixes a validation pass assert when processing ops with quantized element types.
The failure case is added to invalid.mlir
The fix is to re-order the validation checking so that only ops with int/float operands and results pass the first stage of validation pass, so that the remaining checks do not need to handle quantized data types.


Full diff: https://github.com/llvm/llvm-project/pull/134445.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+10-9)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+9)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3ec7354562d23..28e562c813eb3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1018,15 +1018,8 @@ void TosaValidation::runOnOperation() {
     if (op->getDialect() != tosaDialect)
       return;
 
-    // Profile-Extension based validation should be performed at the beginning.
-    if (strictOpSpecAlignment &&
-        failed(profileComp.checkProfile(op, targetEnv)))
-      return signalPassFailure();
-
-    if (strictOpSpecAlignment &&
-        failed(profileComp.checkExtension(op, targetEnv)))
-      return signalPassFailure();
-
+    // perform valid element type check at the beginning to
+    // protect rest of code against quantized element types
     for (Value operand : op->getOperands()) {
       auto elementTy = getElementTypeOrSelf(operand);
       if (!isValidElementType(elementTy)) {
@@ -1044,6 +1037,14 @@ void TosaValidation::runOnOperation() {
       }
     }
 
+    if (strictOpSpecAlignment &&
+        failed(profileComp.checkProfile(op, targetEnv)))
+      return signalPassFailure();
+
+    if (strictOpSpecAlignment &&
+        failed(profileComp.checkExtension(op, targetEnv)))
+      return signalPassFailure();
+
     if (!allowInvalidOpDatatypeCombinations &&
         failed(profileComp.checkInvalid(op))) {
       op->emitOpError("illegal: operand/result data types not supported");
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8cf6d4b154792..12b2379a592c3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -253,6 +253,15 @@ func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>
   return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
 }
 
+// -----
+// CHECK-LABEL: conv2d_quant_any
+func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i32<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>> {
+  %zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv2d' op is not profile-aligned: element type '!quant.any<i8<-8:7>>'}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i32<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+  return %0 : tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+}
+
 // -----
 
 func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {

@Jerry-Ge Jerry-Ge requested review from lhutton1 and GeorgeARM April 4, 2025 21:41
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Tai78641!

@Jerry-Ge Jerry-Ge merged commit cb9afe5 into llvm:main Apr 7, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants