Fix TOSA FP16->INT16 CAST lowering #79299


Merged
merged 3 commits on Jan 30, 2024

Conversation

@RoboTux (Contributor) commented Jan 24, 2024

Currently, a cast from FP to int is implemented by clamping to the min and max
integer values in the floating-point domain and then converting to
integer. However, the max int value is often not representable in the
floating-point input type due to the lack of mantissa bits.

This patch instead uses a select acting on a comparison against max int + 1,
which is representable in floating-point. It also adds a special lowering
for cases where the integer range is wider than the floating-point range,
to clamp infinite values.

@llvmbot (Member) commented Jan 24, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Thomas Preud'homme (RoboTux)

Changes

Currently, a cast from FP to int is implemented by clamping to the min and max
integer values in the floating-point domain and then converting to
integer. However, the max int value is often not representable in the
floating-point input type due to the lack of mantissa bits. This patch
instead uses a select acting on a comparison against max int + 1, which is
representable in floating-point.


Full diff: https://github.com/llvm/llvm-project/pull/79299.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+38-8)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+14-12)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 647592395c8760..96de43caae7364 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,23 +480,53 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     }
 
     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
-      auto intMin = rewriter.create<arith::ConstantOp>(
+      auto intMinFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+      // The input floating-point type has enough mantissa bits to represent
+      // the max int value so just clamp the input in the floating-point
+      // domain and convert to int. Note: the min value can be represented
+      // because it consists of a mantissa with only the lsb set.
+      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+          dstTy.getIntOrFloatBitWidth() - 1) {
+        auto intMaxFP = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(
+                     getElementTypeOrSelf(srcTy),
+                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+                         .getSExtValue()));
+
+        auto clamped =
+            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+      }
+
+      // Otherwise, we can rely on int max + 1 being representable because it
+      // also consists of a single lsb set in the mantissa. So clamp the min
+      // value and compare against that to select the max int value if needed.
+      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    getElementTypeOrSelf(srcTy),
                    APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
-                       .getSExtValue()));
-
-      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+                           .getSExtValue() +
+                       1));
 
-      auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
-
-      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+      auto intMax = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(
+                   getElementTypeOrSelf(dstTy),
+                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+      auto minClampedFP =
+          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+      auto minClamped =
+          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
+      auto overflow = rewriter.create<arith::CmpFOp>(
+          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+                                              minClamped);
     }
 
     // Casting to boolean, integers need to only be checked as not-equal to
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1f63b7d5ca6c8b..b19f9a04bd6f3b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,12 +514,14 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.constant -2.14748365E+9
-  // CHECK: arith.constant 2.14748365E+9
-  // CHECK: math.roundeven
-  // CHECK: arith.minimumf
-  // CHECK: arith.maximumf
-  // CHECK: arith.fptosi
+  // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -2.14748365E+9 : f32
+  // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f32
+  // CHECK: [[CSTMAXP1:%[a-z0-9_]+]] = arith.constant 2.14748365E+9 : f32
+  // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 2147483647 : i32
+  // CHECK: [[MAX:%[a-z0-9_]+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+  // CHECK: [[CONV:%[a-z0-9_]+]] = arith.fptosi [[MAX]] : f32 to i32
+  // CHECK: [[CMP:%[a-z0-9_]+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+  // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
   %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
@@ -552,12 +554,12 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   %0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.constant -1.280000e+02
-  // CHECK: arith.constant 1.270000e+02
-  // CHECK: math.roundeven
-  // CHECK: arith.minimumf
-  // CHECK: arith.maximumf
-  // CHECK: arith.fptosi
+  // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -1.280000e+02 : f16
+  // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+  // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 1.270000e+02 : f16
+  // CHECK: [[MIN:%[a-z0-9_]+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[CLAMP:%[a-z0-9_]+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+  // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
   %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
   return
 }

@RoboTux (Contributor, Author) commented Jan 24, 2024

I'll wait until tomorrow to merge to leave time for others to comment.

@banach-space (Contributor) commented:

> I'll wait til tomorrow to merge to leave time for others to comment.

You may want to add more reviewers ;-)

@RoboTux RoboTux requested review from sabauma and eric-k256 January 24, 2024 16:28
@eric-k256 (Contributor) left a comment

Looks good to me.

@GeorgeARM GeorgeARM requested a review from kuhar January 24, 2024 18:57
@sabauma (Contributor) left a comment

LGTM

@RoboTux (Contributor, Author) commented Jan 29, 2024

@eric-k256 @sabauma I've added a whole new case to deal with widening casts (where infinities need to be handled) and reworded the comments and commit message. I'd appreciate a new review.

@kuhar (Member) left a comment

I've just noticed one thing: this function is 500 lines long and this PR makes it longer. Does this have to be structured like this? I don't have any stake in this code, but I'm concerned about the maintainability.

@RoboTux (Contributor, Author) commented Jan 30, 2024

> I've just noticed one thing: this function is 500 lines long and this PR makes it longer. Does this have to be structured like this? I don't have any stake in this code, but I'm concerned about the maintainability.

I'm happy to fix that in a follow-up patch. Doing it in the same patch would make it harder to read, IMO.

@RoboTux RoboTux merged commit b23e518 into llvm:main Jan 30, 2024
@RoboTux RoboTux deleted the fix_tosa_cast branch April 18, 2024 15:58