Skip to content

Commit cb9afe5

Browse files
authored
[mlir][tosa] Fix validation pass assert (#134445)
This fixes a validation pass assert when processing ops with quantized element types. The failure case is added to invalid.mlir The fix is to re-order the validation checking so that only ops with int/float operands and results pass the first stage of validation pass, so that the remaining checks do not need to handle quantized data types. Signed-off-by: Tai Ly <[email protected]>
1 parent 1c8291f commit cb9afe5

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,15 +1018,8 @@ void TosaValidation::runOnOperation() {
10181018
if (op->getDialect() != tosaDialect)
10191019
return;
10201020

1021-
// Profile-Extension based validation should be performed at the beginning.
1022-
if (strictOpSpecAlignment &&
1023-
failed(profileComp.checkProfile(op, targetEnv)))
1024-
return signalPassFailure();
1025-
1026-
if (strictOpSpecAlignment &&
1027-
failed(profileComp.checkExtension(op, targetEnv)))
1028-
return signalPassFailure();
1029-
1021+
// perform valid element type check at the beginning to
1022+
// protect rest of code against quantized element types
10301023
for (Value operand : op->getOperands()) {
10311024
auto elementTy = getElementTypeOrSelf(operand);
10321025
if (!isValidElementType(elementTy)) {
@@ -1044,6 +1037,14 @@ void TosaValidation::runOnOperation() {
10441037
}
10451038
}
10461039

1040+
if (strictOpSpecAlignment &&
1041+
failed(profileComp.checkProfile(op, targetEnv)))
1042+
return signalPassFailure();
1043+
1044+
if (strictOpSpecAlignment &&
1045+
failed(profileComp.checkExtension(op, targetEnv)))
1046+
return signalPassFailure();
1047+
10471048
if (!allowInvalidOpDatatypeCombinations &&
10481049
failed(profileComp.checkInvalid(op))) {
10491050
op->emitOpError("illegal: operand/result data types not supported");

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,15 @@ func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>
253253
return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
254254
}
255255

256+
// -----
257+
// CHECK-LABEL: conv2d_quant_any
258+
func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i32<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>> {
259+
%zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
260+
// expected-error@+1 {{'tosa.conv2d' op is not profile-aligned: element type '!quant.any<i8<-8:7>>'}}
261+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i32<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>>
262+
return %0 : tensor<1x4x4x8x!quant.any<i32<-8:7>>>
263+
}
264+
256265
// -----
257266

258267
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {

0 commit comments

Comments
 (0)