Skip to content

Commit b23e518

Browse files
authored
Fix TOSA FP16->INT16 CAST lowering (#79299)
Currently, a cast from FP to int is implemented by clamping to the min and max integer values in the floating-point domain and then converting to integer. However, the max int values are often not representable in the floating-point input type due to a lack of mantissa bits. This patch instead uses a select acting on a compare against max int + 1, which is representable in floating-point. It also has a special lowering for cases where the integer range is wider than the floating-point range, in order to clamp the infinite values.
1 parent 8a5bdd8 commit b23e518

File tree

2 files changed

+100
-20
lines changed

2 files changed

+100
-20
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -480,23 +480,88 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
480480
}
481481

482482
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
483-
auto intMin = rewriter.create<arith::ConstantOp>(
483+
auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
484+
485+
const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
486+
// Check whether neither int min nor int max can be represented in the
487+
// input floating-point type due to too short exponent range.
488+
if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
489+
APFloat::semanticsMaxExponent(fltSemantics)) {
490+
// Use cmp + select to replace infinities by int min / int max. Other
491+
// integral values can be represented in the integer space.
492+
auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
493+
auto posInf = rewriter.create<arith::ConstantOp>(
494+
loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
495+
APFloat::getInf(fltSemantics)));
496+
auto negInf = rewriter.create<arith::ConstantOp>(
497+
loc, rewriter.getFloatAttr(
498+
getElementTypeOrSelf(srcTy),
499+
APFloat::getInf(fltSemantics, /*Negative=*/true)));
500+
auto overflow = rewriter.create<arith::CmpFOp>(
501+
loc, arith::CmpFPredicate::UEQ, rounded, posInf);
502+
auto underflow = rewriter.create<arith::CmpFOp>(
503+
loc, arith::CmpFPredicate::UEQ, rounded, negInf);
504+
auto intMin = rewriter.create<arith::ConstantOp>(
505+
loc, rewriter.getIntegerAttr(
506+
getElementTypeOrSelf(dstTy),
507+
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
508+
auto intMax = rewriter.create<arith::ConstantOp>(
509+
loc, rewriter.getIntegerAttr(
510+
getElementTypeOrSelf(dstTy),
511+
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
512+
auto maxClamped =
513+
rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
514+
return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
515+
maxClamped);
516+
}
517+
518+
auto intMinFP = rewriter.create<arith::ConstantOp>(
484519
loc, rewriter.getFloatAttr(
485520
getElementTypeOrSelf(srcTy),
486521
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
487522
.getSExtValue()));
488523

489-
auto intMax = rewriter.create<arith::ConstantOp>(
524+
// Check whether the mantissa has enough bits to represent int max.
525+
if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
526+
dstTy.getIntOrFloatBitWidth() - 1) {
527+
// Int min can also be represented since it is a power of two and thus
528+
// consists of a single leading bit. Therefore we can clamp the input
529+
// in the floating-point domain.
530+
531+
auto intMaxFP = rewriter.create<arith::ConstantOp>(
532+
loc, rewriter.getFloatAttr(
533+
getElementTypeOrSelf(srcTy),
534+
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
535+
.getSExtValue()));
536+
537+
Value clamped =
538+
clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
539+
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
540+
}
541+
542+
// Due to the earlier check we know the exponent range is big enough to represent
543+
// int min. We can therefore rely on int max + 1 being representable as
544+
// well because it's just int min with a positive sign. So clamp the min
545+
// value and compare against that to select the max int value if needed.
546+
auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
490547
loc, rewriter.getFloatAttr(
491548
getElementTypeOrSelf(srcTy),
492549
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
493-
.getSExtValue()));
494-
495-
auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
496-
497-
auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
550+
.getSExtValue() +
551+
1));
498552

499-
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
553+
auto intMax = rewriter.create<arith::ConstantOp>(
554+
loc, rewriter.getIntegerAttr(
555+
getElementTypeOrSelf(dstTy),
556+
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
557+
auto minClampedFP =
558+
rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
559+
auto minClamped =
560+
rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
561+
auto overflow = rewriter.create<arith::CmpFOp>(
562+
loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
563+
return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
564+
minClamped);
500565
}
501566

502567
// Casting to boolean, integers need to only be checked as not-equal to

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -514,12 +514,14 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
514514
%19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>
515515

516516
// CHECK: linalg.generic
517-
// CHECK: arith.constant -2.14748365E+9
518-
// CHECK: arith.constant 2.14748365E+9
519-
// CHECK: math.roundeven
520-
// CHECK: arith.minimumf
521-
// CHECK: arith.maximumf
522-
// CHECK: arith.fptosi
517+
// CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
518+
// CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
519+
// CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
520+
// CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
521+
// CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
522+
// CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
523+
// CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
524+
// CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
523525
%20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
524526

525527
// CHECK: linalg.generic
@@ -552,13 +554,26 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
552554
%0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>
553555

554556
// CHECK: linalg.generic
555-
// CHECK: arith.constant -1.280000e+02
556-
// CHECK: arith.constant 1.270000e+02
557-
// CHECK: math.roundeven
558-
// CHECK: arith.minimumf
559-
// CHECK: arith.maximumf
560-
// CHECK: arith.fptosi
557+
// CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
558+
// CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
559+
// CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
560+
// CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
561+
// CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
562+
// CHECK: arith.fptosi [[CLAMP]] : f16 to i8
561563
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
564+
565+
// CHECK: linalg.generic
566+
// CHECK: [[ROUND:%.+]] = math.roundeven {{%[a-z0-9_]+}} : f16
567+
// CHECK: [[CONV:%.+]] = arith.fptosi [[ROUND]] : f16 to i32
568+
// CHECK: [[POSINF:%.+]] = arith.constant 0x7C00 : f16
569+
// CHECK: [[NEGINF:%.+]] = arith.constant 0xFC00 : f16
570+
// CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
571+
// CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
572+
// CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
573+
// CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
574+
// CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
575+
// CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
576+
%2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
562577
return
563578
}
564579

0 commit comments

Comments
 (0)