@@ -480,23 +480,88 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
480
480
}
481
481
482
482
if (arith::FPToSIOp::areCastCompatible (srcTy, dstTy)) {
483
- auto intMin = rewriter.create <arith::ConstantOp>(
483
+ auto rounded = rewriter.create <math::RoundEvenOp>(loc, args[0 ]);
484
+
485
+ const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics ();
486
+ // Check whether neither int min nor int max can be represented in the
487
+ // input floating-point type due to too short exponent range.
488
+ if (static_cast <int >(dstTy.getIntOrFloatBitWidth ()) - 1 >
489
+ APFloat::semanticsMaxExponent (fltSemantics)) {
490
+ // Use cmp + select to replace infinities by int min / int max. Other
491
+ // integral values can be represented in the integer space.
492
+ auto conv = rewriter.create <arith::FPToSIOp>(loc, dstTy, rounded);
493
+ auto posInf = rewriter.create <arith::ConstantOp>(
494
+ loc, rewriter.getFloatAttr (getElementTypeOrSelf (srcTy),
495
+ APFloat::getInf (fltSemantics)));
496
+ auto negInf = rewriter.create <arith::ConstantOp>(
497
+ loc, rewriter.getFloatAttr (
498
+ getElementTypeOrSelf (srcTy),
499
+ APFloat::getInf (fltSemantics, /* Negative=*/ true )));
500
+ auto overflow = rewriter.create <arith::CmpFOp>(
501
+ loc, arith::CmpFPredicate::UEQ, rounded, posInf);
502
+ auto underflow = rewriter.create <arith::CmpFOp>(
503
+ loc, arith::CmpFPredicate::UEQ, rounded, negInf);
504
+ auto intMin = rewriter.create <arith::ConstantOp>(
505
+ loc, rewriter.getIntegerAttr (
506
+ getElementTypeOrSelf (dstTy),
507
+ APInt::getSignedMinValue (dstTy.getIntOrFloatBitWidth ())));
508
+ auto intMax = rewriter.create <arith::ConstantOp>(
509
+ loc, rewriter.getIntegerAttr (
510
+ getElementTypeOrSelf (dstTy),
511
+ APInt::getSignedMaxValue (dstTy.getIntOrFloatBitWidth ())));
512
+ auto maxClamped =
513
+ rewriter.create <arith::SelectOp>(loc, overflow, intMax, conv);
514
+ return rewriter.create <arith::SelectOp>(loc, underflow, intMin,
515
+ maxClamped);
516
+ }
517
+
518
+ auto intMinFP = rewriter.create <arith::ConstantOp>(
484
519
loc, rewriter.getFloatAttr (
485
520
getElementTypeOrSelf (srcTy),
486
521
APInt::getSignedMinValue (dstTy.getIntOrFloatBitWidth ())
487
522
.getSExtValue ()));
488
523
489
- auto intMax = rewriter.create <arith::ConstantOp>(
524
+ // Check whether the mantissa has enough bits to represent int max.
525
+ if (cast<FloatType>(srcTy).getFPMantissaWidth () >=
526
+ dstTy.getIntOrFloatBitWidth () - 1 ) {
527
+ // Int min can also be represented since it is a power of two and thus
528
+ // consists of a single leading bit. Therefore we can clamp the input
529
+ // in the floating-point domain.
530
+
531
+ auto intMaxFP = rewriter.create <arith::ConstantOp>(
532
+ loc, rewriter.getFloatAttr (
533
+ getElementTypeOrSelf (srcTy),
534
+ APInt::getSignedMaxValue (dstTy.getIntOrFloatBitWidth ())
535
+ .getSExtValue ()));
536
+
537
+ Value clamped =
538
+ clampFloatHelper (loc, rounded, intMinFP, intMaxFP, rewriter);
539
+ return rewriter.create <arith::FPToSIOp>(loc, dstTy, clamped);
540
+ }
541
+
542
+ // Due to earlier check we know exponent range is big enough to represent
543
+ // int min. We can therefore rely on int max + 1 being representable as
544
+ // well because it's just int min with a positive sign. So clamp the min
545
+ // value and compare against that to select the max int value if needed.
546
+ auto intMaxPlusOneFP = rewriter.create <arith::ConstantOp>(
490
547
loc, rewriter.getFloatAttr (
491
548
getElementTypeOrSelf (srcTy),
492
549
APInt::getSignedMaxValue (dstTy.getIntOrFloatBitWidth ())
493
- .getSExtValue ()));
494
-
495
- auto rounded = rewriter.create <math::RoundEvenOp>(loc, args[0 ]);
496
-
497
- auto clamped = clampFloatHelper (loc, rounded, intMin, intMax, rewriter);
550
+ .getSExtValue () +
551
+ 1 ));
498
552
499
- return rewriter.create <arith::FPToSIOp>(loc, dstTy, clamped);
553
+ auto intMax = rewriter.create <arith::ConstantOp>(
554
+ loc, rewriter.getIntegerAttr (
555
+ getElementTypeOrSelf (dstTy),
556
+ APInt::getSignedMaxValue (dstTy.getIntOrFloatBitWidth ())));
557
+ auto minClampedFP =
558
+ rewriter.create <arith::MaximumFOp>(loc, rounded, intMinFP);
559
+ auto minClamped =
560
+ rewriter.create <arith::FPToSIOp>(loc, dstTy, minClampedFP);
561
+ auto overflow = rewriter.create <arith::CmpFOp>(
562
+ loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
563
+ return rewriter.create <arith::SelectOp>(loc, overflow, intMax,
564
+ minClamped);
500
565
}
501
566
502
567
// Casting to boolean, integers need to only be checked as not-equal to
0 commit comments