Skip to content

[TOSA] Allow all integer types in most ops #86509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 26, 2024

Conversation

mgehre-amd
Copy link
Contributor

As discussed in one of the previous TOSA community meetings, we would like to allow for more integer types in the TOSA dialect to enable more use cases.

For strict standards conformance, the TosaValidation pass can be used.

Follow up PRs will extend conversions from TOSA where needed.

As discussed in one of the previous TOSA community meetings,
we would like to allow for more integer types in the TOSA dialect
to enable more use cases.

For strict standards conformance, the TosaValidation pass can be used.

Follow up PRs will extend conversions from TOSA where needed.
@llvmbot
Copy link
Member

llvmbot commented Mar 25, 2024

@llvm/pr-subscribers-mlir-tosa

Author: Matthias Gehre (mgehre-amd)

Changes

As discussed in one of the previous TOSA community meetings, we would like to allow for more integer types in the TOSA dialect to enable more use cases.

For strict standards conformance, the TosaValidation pass can be used.

Follow up PRs will extend conversions from TOSA where needed.


Full diff: https://github.com/llvm/llvm-project/pull/86509.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+6-21)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+48-3)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 0ecded75c5d8bc..306e4a43952088 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1942,7 +1942,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
   );
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5a4d6ff464f19e..cff3de0a69af95 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -38,29 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
 // Used to express accumulator results or compare results.
 //===----------------------------------------------------------------------===//
 
-def Tosa_UInt8 : UI<8>;
-def Tosa_UInt16 : UI<16>;
-
 def Tosa_Int4 : I<4>;
 def Tosa_Int8 : I<8>;
-def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
-def Tosa_Int48 : I<48>;
 def Tosa_Int64 : I<64>;
 
-def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
-                                Tosa_Int16,
-                                Tosa_Int32,
-                                Tosa_Int48,
-                                Tosa_Int64]>;
-
-def Tosa_Bool : I<1>;
-
-// No unsigned unquantized int types.
-def Tosa_Int : AnyTypeOf<[Tosa_Bool,
-                          Tosa_UInt8,
-                          Tosa_UInt16,
-                          Tosa_SignedInt]>;
+// The TOSA dialect allows more types than the TOSA standard to allow for
+// experimentation. For historical reasons, signless is used in the place of
+// signed.
+// The TosaValidation pass can be used to check for standard conformance.
+def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
+                          AnySignlessInteger]>;
 
 def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
                    	        Tosa_Int64]>;
@@ -172,9 +160,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<
 
 def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
 def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
-def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
-def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
-def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
 
 //===----------------------------------------------------------------------===//
 // Attribute predicates and classes.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 967775281ad91f..b669b7362e9432 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -410,6 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool CheckVariable(Operation *op);
   bool CheckVariableReadOrWrite(Operation *op);
 
+  bool isValidElementType(Type type);
+
   SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
   TosaLevel tosaLevel;
   DenseMap<StringAttr, mlir::Type> variablesMap;
@@ -503,15 +505,58 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
   return success();
 }
 
+bool TosaValidation::isValidElementType(Type type) {
+  if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
+    return false;
+  }
+  if (type.isF64()) {
+    return false;
+  }
+  if (auto intTy = dyn_cast<IntegerType>(type)) {
+    if (intTy.isUnsigned()) {
+      switch (intTy.getWidth()) {
+      case 8:
+      case 16:
+        return true;
+      default:
+        return false;
+      }
+    } else {
+      // Signless - treated as signed.
+      switch (intTy.getWidth()) {
+      case 1:
+      case 4:
+      case 8:
+      case 16:
+      case 32:
+      case 48:
+      case 64:
+        return true;
+      default:
+        return false;
+      }
+    }
+    return false;
+  }
+  return true;
+}
+
 void TosaValidation::runOnOperation() {
   configLevelAndProfile();
   getOperation().walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
-      if ((profile == TosaProfileEnum::BaseInference) &&
-          isa<FloatType>(getElementTypeOrSelf(operand))) {
+      auto elementTy = getElementTypeOrSelf(operand);
+      if (!isValidElementType(elementTy)) {
+        op->emitOpError() << "failed level check: element type " << elementTy
+                          << " is not legal";
         return signalPassFailure();
       }
-      if (getElementTypeOrSelf(operand).isF64()) {
+    }
+    for (Type resultTy : op->getResultTypes()) {
+      auto elementTy = getElementTypeOrSelf(resultTy);
+      if (!isValidElementType(elementTy)) {
+        op->emitOpError() << "failed level check: element type " << elementTy
+                          << " is not legal";
         return signalPassFailure();
       }
     }
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 35ecbcc799e3df..1d3ef282836705 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -115,6 +115,22 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
 
 // -----
 
+func.func @test_const_i2(%arg0 : tensor<1xi2>) {
+  // expected-error@+1 {{'tosa.const' op failed level check: element type 'i2' is not legal}}
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
+  return
+}
+
+// -----
+
+func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
+  // expected-error@+1 {{'tosa.const' op failed level check: element type 'ui32' is not legal}}
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
+  return
+}
+
+// -----
+
 func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
   %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :

@llvmbot
Copy link
Member

llvmbot commented Mar 25, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

As discussed in one of the previous TOSA community meetings, we would like to allow for more integer types in the TOSA dialect to enable more use cases.

For strict standards conformance, the TosaValidation pass can be used.

Follow up PRs will extend conversions from TOSA where needed.


Full diff: https://github.com/llvm/llvm-project/pull/86509.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+6-21)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+48-3)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 0ecded75c5d8bc..306e4a43952088 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1942,7 +1942,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
   );
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5a4d6ff464f19e..cff3de0a69af95 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -38,29 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
 // Used to express accumulator results or compare results.
 //===----------------------------------------------------------------------===//
 
-def Tosa_UInt8 : UI<8>;
-def Tosa_UInt16 : UI<16>;
-
 def Tosa_Int4 : I<4>;
 def Tosa_Int8 : I<8>;
-def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
-def Tosa_Int48 : I<48>;
 def Tosa_Int64 : I<64>;
 
-def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
-                                Tosa_Int16,
-                                Tosa_Int32,
-                                Tosa_Int48,
-                                Tosa_Int64]>;
-
-def Tosa_Bool : I<1>;
-
-// No unsigned unquantized int types.
-def Tosa_Int : AnyTypeOf<[Tosa_Bool,
-                          Tosa_UInt8,
-                          Tosa_UInt16,
-                          Tosa_SignedInt]>;
+// The TOSA dialect allows more types than the TOSA standard to allow for
+// experimentation. For historical reasons, signless is used in the place of
+// signed.
+// The TosaValidation pass can be used to check for standard conformance.
+def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
+                          AnySignlessInteger]>;
 
 def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
                    	        Tosa_Int64]>;
@@ -172,9 +160,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<
 
 def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
 def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
-def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
-def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
-def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
 
 //===----------------------------------------------------------------------===//
 // Attribute predicates and classes.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 967775281ad91f..b669b7362e9432 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -410,6 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool CheckVariable(Operation *op);
   bool CheckVariableReadOrWrite(Operation *op);
 
+  bool isValidElementType(Type type);
+
   SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
   TosaLevel tosaLevel;
   DenseMap<StringAttr, mlir::Type> variablesMap;
@@ -503,15 +505,58 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
   return success();
 }
 
+bool TosaValidation::isValidElementType(Type type) {
+  if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
+    return false;
+  }
+  if (type.isF64()) {
+    return false;
+  }
+  if (auto intTy = dyn_cast<IntegerType>(type)) {
+    if (intTy.isUnsigned()) {
+      switch (intTy.getWidth()) {
+      case 8:
+      case 16:
+        return true;
+      default:
+        return false;
+      }
+    } else {
+      // Signless - treated as signed.
+      switch (intTy.getWidth()) {
+      case 1:
+      case 4:
+      case 8:
+      case 16:
+      case 32:
+      case 48:
+      case 64:
+        return true;
+      default:
+        return false;
+      }
+    }
+    return false;
+  }
+  return true;
+}
+
 void TosaValidation::runOnOperation() {
   configLevelAndProfile();
   getOperation().walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
-      if ((profile == TosaProfileEnum::BaseInference) &&
-          isa<FloatType>(getElementTypeOrSelf(operand))) {
+      auto elementTy = getElementTypeOrSelf(operand);
+      if (!isValidElementType(elementTy)) {
+        op->emitOpError() << "failed level check: element type " << elementTy
+                          << " is not legal";
         return signalPassFailure();
       }
-      if (getElementTypeOrSelf(operand).isF64()) {
+    }
+    for (Type resultTy : op->getResultTypes()) {
+      auto elementTy = getElementTypeOrSelf(resultTy);
+      if (!isValidElementType(elementTy)) {
+        op->emitOpError() << "failed level check: element type " << elementTy
+                          << " is not legal";
         return signalPassFailure();
       }
     }
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 35ecbcc799e3df..1d3ef282836705 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -115,6 +115,22 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
 
 // -----
 
+func.func @test_const_i2(%arg0 : tensor<1xi2>) {
+  // expected-error@+1 {{'tosa.const' op failed level check: element type 'i2' is not legal}}
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
+  return
+}
+
+// -----
+
+func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
+  // expected-error@+1 {{'tosa.const' op failed level check: element type 'ui32' is not legal}}
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
+  return
+}
+
+// -----
+
 func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
   %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :

Copy link

✅ With the latest revision this PR passed the Python code formatter.

Copy link

✅ With the latest revision this PR passed the C/C++ code formatter.

@eric-k256
Copy link
Contributor

Adding @sjarus for additional review. At a first read, this looks okay. We've been working on some improvements for the validation pass to do a more detailed per-op checking of types, but that can come later. Ideally we will get to the point where we can create some of the validation from the TOSA spec xml.

@mgehre-amd mgehre-amd merged commit c6d419c into llvm:main Mar 26, 2024
@mgehre-amd mgehre-amd deleted the matthias.tosa_all_int branch March 26, 2024 21:27
mgehre-amd added a commit to Xilinx/llvm-project that referenced this pull request May 6, 2024
As discussed in one of the previous TOSA community meetings, we would
like to allow for more integer types in the TOSA dialect to enable more
use cases.

For strict standards conformance, the TosaValidation pass can be used.

Follow up PRs will extend conversions from TOSA where needed.
mgehre-amd added a commit that referenced this pull request May 14, 2024
After #86509 allowed all integer types in TOSA ops, this PR allows TOSA
ops on all floating point types.
This helps to experiment with `f64` and 8-bit float types when spec
conformance is not required.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants