Fix TOSA FP16->INT16 CAST lowering #79299
Currently, a cast from FP to int is implemented by clamping to the min and max integer values in the floating-point domain and then converting to integer. However, the max int value is often not representable in the floating-point input type because that type lacks mantissa bits. This patch instead uses a select acting on a compare against max int + 1, which is representable in floating-point.
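To illustrate the representability problem described above, here is a minimal C++ sketch. It uses `float` and `int32_t` as a stand-in for the f16/i16 pair in the title, since standard C++ has no portable half-precision type; the same reasoning applies to f16, whose 11 bits of precision cannot hold the 15-bit value 32767. The helper names `isExactInFloat` and `upperClampBoundF32` are hypothetical and not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Returns true if `v` survives a round trip through float unchanged,
// i.e. it is exactly representable in the float type.
bool isExactInFloat(int64_t v) {
  return static_cast<int64_t>(static_cast<float>(v)) == v;
}

// Models the upper clamp bound used by the old lowering for an i32 target:
// the max int value converted to the floating-point source type.
float upperClampBoundF32() {
  const int32_t intMax = std::numeric_limits<int32_t>::max();
  // Whichever way the conversion rounds, it cannot yield 2^31 - 1 itself.
  assert(!isExactInFloat(intMax));
  return static_cast<float>(intMax);
}
```

On a typical IEEE-754 platform with round-to-nearest conversions, `upperClampBoundF32()` yields 2^31, which is one past the largest `int32_t`, so clamping against it does not keep the value in range. By contrast, `isExactInFloat(2147483648LL)` and `isExactInFloat(-2147483648LL)` both hold, because a power of two needs only a single mantissa bit.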
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg
Author: Thomas Preud'homme (RoboTux)
Changes: Currently cast from FP to int is implemented by clamping on the min and max
Full diff: https://github.com/llvm/llvm-project/pull/79299.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 647592395c8760..96de43caae7364 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -480,23 +480,53 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
- auto intMin = rewriter.create<arith::ConstantOp>(
+ auto intMinFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));
- auto intMax = rewriter.create<arith::ConstantOp>(
+ auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+
+ // The input floating-point type has enough mantissa bits to represent
+ // the max int value so just clamp the input in the floating-point
+ // domain and convert to int. Note: the min value can be represented
+ // because it consists of a mantissa with only the lsb set.
+ if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
+ dstTy.getIntOrFloatBitWidth() - 1) {
+ auto intMaxFP = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(
+ getElementTypeOrSelf(srcTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
+
+ auto clamped =
+ clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+ return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ }
+
+ // Otherwise, we can rely on int max + 1 being representable because it
+ // also consists of a single lsb set in the mantissa. So clamp the min
+ // value and compare against that to select the max int value if needed.
+ auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
- .getSExtValue()));
-
- auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
+ .getSExtValue() +
+ 1));
- auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
-
- return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
+ auto intMax = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ getElementTypeOrSelf(dstTy),
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+ auto minClampedFP =
+ rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
+ auto minClamped =
+ rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
+ auto overflow = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
+ return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+ minClamped);
}
// Casting to boolean, integers need to only be checked as not-equal to
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1f63b7d5ca6c8b..b19f9a04bd6f3b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -514,12 +514,14 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.constant -2.14748365E+9
- // CHECK: arith.constant 2.14748365E+9
- // CHECK: math.roundeven
- // CHECK: arith.minimumf
- // CHECK: arith.maximumf
- // CHECK: arith.fptosi
+ // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -2.14748365E+9 : f32
+ // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f32
+ // CHECK: [[CSTMAXP1:%[a-z0-9_]+]] = arith.constant 2.14748365E+9 : f32
+ // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 2147483647 : i32
+ // CHECK: [[MAX:%[a-z0-9_]+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+ // CHECK: [[CONV:%[a-z0-9_]+]] = arith.fptosi [[MAX]] : f32 to i32
+ // CHECK: [[CMP:%[a-z0-9_]+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
+ // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
%20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
// CHECK: linalg.generic
@@ -552,12 +554,12 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
%0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.constant -1.280000e+02
- // CHECK: arith.constant 1.270000e+02
- // CHECK: math.roundeven
- // CHECK: arith.minimumf
- // CHECK: arith.maximumf
- // CHECK: arith.fptosi
+ // CHECK: [[CSTMIN:%[a-z0-9_]+]] = arith.constant -1.280000e+02 : f16
+ // CHECK: [[ROUND:%[a-z0-9_]+]] = math.roundeven {{%[a-z0-9_]+}} : f16
+ // CHECK: [[CSTMAX:%[a-z0-9_]+]] = arith.constant 1.270000e+02 : f16
+ // CHECK: [[MIN:%[a-z0-9_]+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+ // CHECK: [[CLAMP:%[a-z0-9_]+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
+ // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
return
}
I'll wait til tomorrow to merge to leave time for others to comment.
You may want to add more reviewers ;-)
Looks good to me.
LGTM
@eric-k256 @sabauma I've added a whole case to deal with widening casts (where infinities need to be handled) and reworded the comments and commit message. I'd appreciate a new review.
I've just noticed one thing: this function is 500 lines long and this PR makes it longer. Does this have to be structured like this? I don't have any stake in this code, but I'm concerned about the maintainability.
I'm happy to fix that in a follow-up patch. Doing it in the same patch would make it harder to read IMO.
Currently, a cast from FP to int is implemented by clamping to the min and
max integer values in the floating-point domain and then converting to
integer. However, the max int value is often not representable in the
floating-point input type because that type lacks mantissa bits.
This patch instead uses a select acting on a compare against max int + 1,
which is representable in floating-point. It also has a special lowering
for cases where the integer range is wider than the floating-point range,
to clamp the infinite values.
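
As a rough scalar model of the select-based lowering described above, the following C++ sketch mirrors the sequence in the diff (round to even, clamp against min, compare against max int + 1, select) for an f32 to i32 cast. The function name `saturatingCastF32ToI32` is hypothetical. One deliberate difference, stated as an assumption: the diff converts unconditionally and then selects, whereas in C++ an out-of-range float-to-int conversion is undefined behavior, so this sketch tests for overflow before converting.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

int32_t saturatingCastF32ToI32(float x) {
  const float minFP = -2147483648.0f;        // INT32_MIN, exact in float
  const float maxPlusOneFP = 2147483648.0f;  // INT32_MAX + 1, exact in float

  // Round to nearest, ties to even (assumes the default rounding mode).
  const float rounded = std::nearbyint(x);

  // Models an unordered "uge" compare: true when rounded >= bound or when
  // rounded is NaN.
  const bool overflow = !(rounded < maxPlusOneFP);
  if (overflow)
    return std::numeric_limits<int32_t>::max();

  // Clamp the low end only; -infinity becomes INT32_MIN here.
  const float minClamped = std::fmax(rounded, minFP);
  assert(minClamped >= minFP && minClamped < maxPlusOneFP);
  return static_cast<int32_t>(minClamped);
}
```

With this model, `saturatingCastF32ToI32(3.0e9f)` and positive infinity both saturate to the largest `int32_t`, negative infinity saturates to the smallest, and in-range inputs such as `2.5f` round to the nearest even integer. Because the compare is unordered, NaN inputs also take the max int branch in this sketch.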