Skip to content

[mlir][tosa] Disallow invalid datatype combinations in the validation pass #131595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2025

Conversation

lhutton1
Copy link
Contributor

@lhutton1 lhutton1 commented Mar 17, 2025

This commit checks if the operands/results of an operator can be found in the profile compliance mapping, if it isn't the operator is considered invalid. As a result, operator datatype combinations that are not listed under the "Supported Data Types" of the TOSA specification are disallowed and the validation pass results in failure.

Note: We should aim to merge this after #131208.

@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

This commit checks if the operands/results of an operator can be found in the profile compliance mapping, if it isn't the operator is considered invalid. As a result, operator datatype combinations that are not listed under the "Supported Data Types" of the TOSA specification are disallowed and the validation pass results in failure.

Note: We should aim to merge this after #131208.


Full diff: https://github.com/llvm/llvm-project/pull/131595.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+5)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+30-11)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+5)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+19)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 69b827fe14dee..da187d8316989 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -115,6 +115,7 @@ class TosaProfileCompliance {
   // environment.
   LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
   LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
+  LogicalResult checkInvalid(Operation *op);
 
   template <typename T>
   LogicalResult checkProfileOrExtension(
@@ -163,6 +164,10 @@ class TosaProfileCompliance {
   stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
 
 private:
+  template <typename T>
+  FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
+                                                  CheckCondition &condition);
+
   OperationProfileComplianceMap profileComplianceMap;
   OperationExtensionComplianceMap extensionComplianceMap;
 };
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ed2c40598458c..9523146581f10 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -300,6 +300,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // Tosa Profile And Extension Compliance Checker
 //===----------------------------------------------------------------------===//
 
+template <typename T>
+FailureOr<SmallVector<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op,
+                                             CheckCondition &condition) {
+  const std::string opName = op->getName().getStringRef().str();
+  const auto complianceMap = getProfileComplianceMap<T>();
+  const auto it = complianceMap.find(opName);
+  if (it == complianceMap.end())
+    return {};
+
+  return findMatchedProfile<T>(op, it->second, condition);
+}
+
 template <typename T>
 LogicalResult TosaProfileCompliance::checkProfileOrExtension(
     Operation *op, const tosa::TargetEnv &targetEnv,
@@ -309,11 +322,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
   if (specRequiredModeSet.size() == 0)
     return success();
 
-  auto opName = op->getName().getStringRef().str();
-  auto compMap = getProfileComplianceMap<T>();
-  auto it = compMap.find(opName);
-
-  if (it == compMap.end()) {
+  CheckCondition condition = CheckCondition::invalid;
+  const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
+  if (failed(maybeOpRequiredMode)) {
     // Operators such as variable and shape ops do not have an operand type
     // restriction. When the profile compliance information of operation is not
     // found, confirm if the target have enabled the profile required from the
@@ -334,12 +345,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
     return failure();
   }
 
-  CheckCondition condition = CheckCondition::invalid;
-  // Find the profiles or extensions requirement according to the signature of
-  // type of the operand list.
-  SmallVector<T> opRequiredMode =
-      findMatchedProfile<T>(op, it->second, condition);
-
+  // Find the required profiles or extensions according to the operand type
+  // combination.
+  const auto opRequiredMode = maybeOpRequiredMode.value();
   if (opRequiredMode.size() == 0) {
     // No matched restriction found.
     return success();
@@ -419,6 +427,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
   return success();
 }
 
+LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
+  CheckCondition condition = CheckCondition::invalid;
+  const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
+  const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+  if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
+      !maybeProfDef.value().size() && !maybeExtDef.value().size())
+    return failure();
+
+  return success();
+}
+
 // Find the profiles or extensions requirement according to the signature of
 // type of the operand list.
 template <typename T>
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 79c13793d7713..7604a33218a88 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1042,6 +1042,11 @@ void TosaValidation::runOnOperation() {
       }
     }
 
+    if (failed(profileComp.checkInvalid(op))) {
+      op->emitOpError("illegal: operand/result data types not supported");
+      return signalPassFailure();
+    }
+
     // Some uses of TOSA rely on the constant operands of particular
     // operations.
     if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index ecd5c792e08b6..22b07e69d3b87 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -14,10 +14,10 @@ func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
 // -----
 
 // check that -tosa-validate level checking kick in
-func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
   // expected-error@+1 {{'tosa.abs' op failed level check: unranked tensor}}
-  %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
-  return %0 : tensor<*xi8>
+  %0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index fd9b3d5f23483..1598aefd95823 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -6,9 +6,9 @@
 
 // -----
 
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %0 : tensor<13x21x3xi8>
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 5b591e3c5f45c..5509b28a2f6d1 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1852,3 +1852,22 @@ func.func @test_maxpool2d_unexpected_output_width(%arg0: tensor<1x32x32x8xf32>)
          (tensor<1x32x32x8xf32>) -> tensor<1x32x2x8xf32>
   return %0 : tensor<1x32x2x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_add_i1
+func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+  // expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}
+  %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_out_i16
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+  // expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 13952716a9611..10140cc0a1e9b 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<
 
 // -----
 
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
   // expected-error@+1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %0 : tensor<13x21x3xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bdf18ec823128..0f469761d89e3 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso
 
 // -----
 
-func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
+func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
   // expected-error@+1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
-    %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
-    return %0 : tensor<1x1x1x1x1x1x64xi16>
+    %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
+    return %0 : tensor<1x1x1x1x1x1x64xi32>
 }
 
 // -----

@lhutton1 lhutton1 force-pushed the disallow-invalid-dtype-combinations branch 2 times, most recently from 3bb61ab to edc7f32 Compare March 19, 2025 15:01
… pass

This commit checks if the operands/results of an operator can be found
in the profile compliance mapping, if it isn't the operator is considered
invalid. As a result, operator datatype combinations that are not listed
under the "Supported Data Types" of the TOSA specification are disallowed
and the validation pass results in failure.

Signed-off-by: Luke Hutton <[email protected]>
Change-Id: Iab36dd84cdbf188015c80b066c321edbb2efc0ff
@lhutton1 lhutton1 force-pushed the disallow-invalid-dtype-combinations branch from edc7f32 to f891830 Compare March 21, 2025 11:49
@lhutton1
Copy link
Contributor Author

Friendly ping for review

@lhutton1 lhutton1 merged commit d4570ea into llvm:main Mar 25, 2025
11 checks passed
@lhutton1 lhutton1 deleted the disallow-invalid-dtype-combinations branch March 28, 2025 15:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants