-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Fix validation pass assert #134445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This fixes a validation pass assert when processing ops with quantized element types. The fix is to re-order the validation checking so that only ops with int/float operands and results pass the first stage of validation pass, so that the remaining checks do not need to handle quantized data types. Signed-off-by: Tai Ly <[email protected]> Change-Id: Ibd71c0de019f2395721035ed51b4983b67d56d61
@llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesThis fixes a validation pass assert when processing ops with quantized element types. Full diff: https://github.com/llvm/llvm-project/pull/134445.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3ec7354562d23..28e562c813eb3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1018,15 +1018,8 @@ void TosaValidation::runOnOperation() {
if (op->getDialect() != tosaDialect)
return;
- // Profile-Extension based validation should be performed at the beginning.
- if (strictOpSpecAlignment &&
- failed(profileComp.checkProfile(op, targetEnv)))
- return signalPassFailure();
-
- if (strictOpSpecAlignment &&
- failed(profileComp.checkExtension(op, targetEnv)))
- return signalPassFailure();
-
+ // perform valid element type check at the beginning to
+ // protect rest of code against quantized element types
for (Value operand : op->getOperands()) {
auto elementTy = getElementTypeOrSelf(operand);
if (!isValidElementType(elementTy)) {
@@ -1044,6 +1037,14 @@ void TosaValidation::runOnOperation() {
}
}
+ if (strictOpSpecAlignment &&
+ failed(profileComp.checkProfile(op, targetEnv)))
+ return signalPassFailure();
+
+ if (strictOpSpecAlignment &&
+ failed(profileComp.checkExtension(op, targetEnv)))
+ return signalPassFailure();
+
if (!allowInvalidOpDatatypeCombinations &&
failed(profileComp.checkInvalid(op))) {
op->emitOpError("illegal: operand/result data types not supported");
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8cf6d4b154792..12b2379a592c3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -253,6 +253,15 @@ func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>
return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
}
+// -----
+// CHECK-LABEL: conv2d_quant_any
+func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i32<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>> {
+ %zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.conv2d' op is not profile-aligned: element type '!quant.any<i8<-8:7>>'}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i32<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+ return %0 : tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+}
+
// -----
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
|
@llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis fixes a validation pass assert when processing ops with quantized element types. Full diff: https://github.com/llvm/llvm-project/pull/134445.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3ec7354562d23..28e562c813eb3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1018,15 +1018,8 @@ void TosaValidation::runOnOperation() {
if (op->getDialect() != tosaDialect)
return;
- // Profile-Extension based validation should be performed at the beginning.
- if (strictOpSpecAlignment &&
- failed(profileComp.checkProfile(op, targetEnv)))
- return signalPassFailure();
-
- if (strictOpSpecAlignment &&
- failed(profileComp.checkExtension(op, targetEnv)))
- return signalPassFailure();
-
+ // perform valid element type check at the beginning to
+ // protect rest of code against quantized element types
for (Value operand : op->getOperands()) {
auto elementTy = getElementTypeOrSelf(operand);
if (!isValidElementType(elementTy)) {
@@ -1044,6 +1037,14 @@ void TosaValidation::runOnOperation() {
}
}
+ if (strictOpSpecAlignment &&
+ failed(profileComp.checkProfile(op, targetEnv)))
+ return signalPassFailure();
+
+ if (strictOpSpecAlignment &&
+ failed(profileComp.checkExtension(op, targetEnv)))
+ return signalPassFailure();
+
if (!allowInvalidOpDatatypeCombinations &&
failed(profileComp.checkInvalid(op))) {
op->emitOpError("illegal: operand/result data types not supported");
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8cf6d4b154792..12b2379a592c3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -253,6 +253,15 @@ func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>
return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
}
+// -----
+// CHECK-LABEL: conv2d_quant_any
+func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i32<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>> {
+ %zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.conv2d' op is not profile-aligned: element type '!quant.any<i8<-8:7>>'}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i32<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+ return %0 : tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+}
+
// -----
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Tai78641!
This fixes a validation pass assert when processing ops with quantized element types.
The failure case is added to invalid.mlir
The fix is to re-order the validation checking so that only ops with int/float operands and results pass the first stage of validation pass, so that the remaining checks do not need to handle quantized data types.