-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Disallow invalid datatype combinations in the validation pass #131595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tosa] Disallow invalid datatype combinations in the validation pass #131595
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesThis commit checks if the operands/results of an operator can be found in the profile compliance mapping, if it isn't the operator is considered invalid. As a result, operator datatype combinations that are not listed under the "Supported Data Types" of the TOSA specification are disallowed and the validation pass results in failure. Note: We should aim to merge this after #131208. Full diff: https://github.com/llvm/llvm-project/pull/131595.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 69b827fe14dee..da187d8316989 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -115,6 +115,7 @@ class TosaProfileCompliance {
// environment.
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
+ LogicalResult checkInvalid(Operation *op);
template <typename T>
LogicalResult checkProfileOrExtension(
@@ -163,6 +164,10 @@ class TosaProfileCompliance {
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
private:
+ template <typename T>
+ FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
+ CheckCondition &condition);
+
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ed2c40598458c..9523146581f10 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -300,6 +300,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// Tosa Profile And Extension Compliance Checker
//===----------------------------------------------------------------------===//
+template <typename T>
+FailureOr<SmallVector<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op,
+ CheckCondition &condition) {
+ const std::string opName = op->getName().getStringRef().str();
+ const auto complianceMap = getProfileComplianceMap<T>();
+ const auto it = complianceMap.find(opName);
+ if (it == complianceMap.end())
+ return {};
+
+ return findMatchedProfile<T>(op, it->second, condition);
+}
+
template <typename T>
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
Operation *op, const tosa::TargetEnv &targetEnv,
@@ -309,11 +322,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
- auto opName = op->getName().getStringRef().str();
- auto compMap = getProfileComplianceMap<T>();
- auto it = compMap.find(opName);
-
- if (it == compMap.end()) {
+ CheckCondition condition = CheckCondition::invalid;
+ const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
+ if (failed(maybeOpRequiredMode)) {
// Operators such as variable and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
@@ -334,12 +345,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
return failure();
}
- CheckCondition condition = CheckCondition::invalid;
- // Find the profiles or extensions requirement according to the signature of
- // type of the operand list.
- SmallVector<T> opRequiredMode =
- findMatchedProfile<T>(op, it->second, condition);
-
+ // Find the required profiles or extensions according to the operand type
+ // combination.
+ const auto opRequiredMode = maybeOpRequiredMode.value();
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -419,6 +427,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
return success();
}
+LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
+ CheckCondition condition = CheckCondition::invalid;
+ const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
+ const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
+ !maybeProfDef.value().size() && !maybeExtDef.value().size())
+ return failure();
+
+ return success();
+}
+
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 79c13793d7713..7604a33218a88 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1042,6 +1042,11 @@ void TosaValidation::runOnOperation() {
}
}
+ if (failed(profileComp.checkInvalid(op))) {
+ op->emitOpError("illegal: operand/result data types not supported");
+ return signalPassFailure();
+ }
+
// Some uses of TOSA rely on the constant operands of particular
// operations.
if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index ecd5c792e08b6..22b07e69d3b87 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -14,10 +14,10 @@ func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
// -----
// check that -tosa-validate level checking kick in
-func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// expected-error@+1 {{'tosa.abs' op failed level check: unranked tensor}}
- %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
- return %0 : tensor<*xi8>
+ %0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+ return %0 : tensor<*xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index fd9b3d5f23483..1598aefd95823 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -6,9 +6,9 @@
// -----
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
- %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
- return %0 : tensor<13x21x3xi8>
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 5b591e3c5f45c..5509b28a2f6d1 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1852,3 +1852,22 @@ func.func @test_maxpool2d_unexpected_output_width(%arg0: tensor<1x32x32x8xf32>)
(tensor<1x32x32x8xf32>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_add_i1
+func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ // expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_out_i16
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+ // expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 13952716a9611..10140cc0a1e9b 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<
// -----
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
// expected-error@+1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
- %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
- return %0 : tensor<13x21x3xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bdf18ec823128..0f469761d89e3 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso
// -----
-func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
+func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
// expected-error@+1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
- return %0 : tensor<1x1x1x1x1x1x64xi16>
+ %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
+ return %0 : tensor<1x1x1x1x1x1x64xi32>
}
// -----
|
3bb61ab
to
edc7f32
Compare
… pass This commit checks if the operands/results of an operator can be found in the profile compliance mapping, if it isn't the operator is considered invalid. As a result, operator datatype combinations that are not listed under the "Supported Data Types" of the TOSA specification are disallowed and the validation pass results in failure. Signed-off-by: Luke Hutton <[email protected]> Change-Id: Iab36dd84cdbf188015c80b066c321edbb2efc0ff
edc7f32
to
f891830
Compare
Friendly ping for review |
This commit checks if the operands/results of an operator can be found in the profile compliance mapping, if it isn't the operator is considered invalid. As a result, operator datatype combinations that are not listed under the "Supported Data Types" of the TOSA specification are disallowed and the validation pass results in failure.
Note: We should aim to merge this after #131208.