Skip to content

[mlir] [TOSA] Allow any floating point type #91745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2024

Conversation

mgehre-amd
Copy link
Contributor

After #86509 allowed all integer types in TOSA ops, this PR allows TOSA ops on all floating point types.
This helps to experiment with f64 and 8-bit float types when spec conformance is not required.

@llvmbot
Copy link
Member

llvmbot commented May 10, 2024

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

After #86509 allowed all integer types in TOSA ops, this PR allows TOSA ops on all floating point types.
This helps to experiment with f64 and 8-bit float types when spec conformance is not required.


Full diff: https://github.com/llvm/llvm-project/pull/91745.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+4-17)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+4-5)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 97a36c49d01b3..7871b46724a03 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1857,11 +1857,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
   }];
 
   let arguments = (ins
-    Tosa_Tensor_Plus_F64:$input
+    Tosa_Tensor:$input
   );
 
   let results = (outs
-    Tosa_Tensor_Plus_F64:$output
+    Tosa_Tensor:$output
   );
 
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
@@ -1944,7 +1944,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
   );
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 3687891fe4b7c..14fc9c7a6730c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -71,28 +71,16 @@ def Tosa_QuantizedInt	: AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
                                      Tosa_QuantizedType<"int16", [16, 0], 1>,
                                      Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
-//===----------------------------------------------------------------------===//
-// Floating-point types.
-//===----------------------------------------------------------------------===//
-def Tosa_Float : AnyTypeOf<[
-                            F32,
-			    F16,
-			    BF16]>;
-
 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
-def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
+def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                "number">;
 
-// Add F64 type support just for tosa::CastOp and tosa::ConstOp
-def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64],
-                               "number_plus_f64">;
-
 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
 // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, Tosa_Float]>;
+                             Tosa_QuantizedInt, AnyFloat]>;
 
 //===----------------------------------------------------------------------===//
 // Tensor types
@@ -101,18 +89,17 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
 def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
 def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
 
-def Tosa_FloatTensor : TensorOf<[Tosa_Float]>;
+def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
 
 // Either ranked or unranked tensor of TOSA supported element types.
 def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
-def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;
 
 // Must be ranked but no further constraints
 def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
 
 // Any tensor element type allowed in Tosa ops.
 def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
-                                Tosa_Float.predicate]>, "tosa.dtype">;
+                                AnyFloat.predicate]>, "tosa.dtype">;
 
 class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
   AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 539501082fd3f..b78c372af77e6 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -506,11 +506,10 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
 }
 
 bool TosaValidation::isValidElementType(Type type) {
-  if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
-    return false;
-  }
-  if (type.isF64()) {
-    return false;
+  if (isa<FloatType>(type)) {
+    if (profile == TosaProfileEnum::BaseInference)
+      return false;
+    return type.isF32() || type.isF16() || type.isBF16();
   }
   if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isUnsigned()) {
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 730ac41dd7a8d..cb38d4d81ca2e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -20,7 +20,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
 // -----
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}}
+  // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d8dd878051f18..9b652f2d0bd14 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -131,6 +131,14 @@ func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
 
 // -----
 
+func.func @test_const_f64(%arg0 : tensor<1xf64>) {
+  // expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'f64' is not legal}}
+  %0 = "tosa.const"() {value = dense<0.0> : tensor<1xf64>} : () -> tensor<1xf64>
+  return
+}
+
+// -----
+
 func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
   %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having integer and floating-point consistent seems like a good goal.

@mgehre-amd mgehre-amd merged commit ea23897 into llvm:main May 14, 2024
@mgehre-amd mgehre-amd deleted the mgehre.tosa_all_float branch May 14, 2024 06:28
mlevesquedion pushed a commit to openxla/stablehlo that referenced this pull request May 20, 2024
Had to update the TOSA tests since all floating point types are now
supported: llvm/llvm-project#91745
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants