Skip to content

Commit 50df541

Browse files
committed
[InstCombine] Add folds for (fp_binop ({s|u}itofp x), ({s|u}itofp y))
The full fold is one of the following: 1) `(fp_binop ({s|u}itofp x), ({s|u}itofp y))` -> `({s|u}itofp (int_binop x, y))` 2) `(fp_binop ({s|u}itofp x), FpC)` -> `({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))` And support the following binops: `fmul` -> `mul` `fadd` -> `add` `fsub` -> `sub` Proofs: https://alive2.llvm.org/ce/z/zuacA8 The proofs time out, so they must be reproduced locally.
1 parent 8036d39 commit 50df541

File tree

6 files changed

+223
-148
lines changed

6 files changed

+223
-148
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2778,6 +2778,9 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
27782778
if (Instruction *X = foldFNegIntoConstant(I, DL))
27792779
return X;
27802780

2781+
if (Instruction *R = foldFBinOpOfIntCasts(I))
2782+
return R;
2783+
27812784
Value *X, *Y;
27822785
Constant *C;
27832786

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
769769
if (Instruction *R = foldFPSignBitOps(I))
770770
return R;
771771

772+
if (Instruction *R = foldFBinOpOfIntCasts(I))
773+
return R;
774+
772775
// X * -1.0 --> -X
773776
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
774777
if (match(Op1, m_SpecificFP(-1.0)))

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 157 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,71 +1402,168 @@ Value *InstCombinerImpl::dyn_castNegVal(Value *V) const {
14021402
}
14031403

14041404
// Try to fold:
1405-
// 1) (add (sitofp x), (sitofp y))
1406-
// -> (sitofp (add x, y))
1407-
// 2) (add (sitofp x), FpC)
1408-
// -> (sitofp (add x, (fptosi FpC)))
1405+
// 1) (fp_binop ({s|u}itofp x), ({s|u}itofp y))
1406+
// -> ({s|u}itofp (int_binop x, y))
1407+
// 2) (fp_binop ({s|u}itofp x), FpC)
1408+
// -> ({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))
14091409
Instruction *InstCombinerImpl::foldFBinOpOfIntCasts(BinaryOperator &BO) {
1410-
// Check for (fadd double (sitofp x), y), see if we can merge this into an
1411-
// integer add followed by a promotion.
1412-
Value *LHS = BO.getOperand(0), *RHS = BO.getOperand(1);
1413-
if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
1414-
Value *LHSIntVal = LHSConv->getOperand(0);
1415-
Type *FPType = LHSConv->getType();
1416-
1417-
// TODO: This check is overly conservative. In many cases known bits
1418-
// analysis can tell us that the result of the addition has less significant
1419-
// bits than the integer type can hold.
1420-
auto IsValidPromotion = [](Type *FTy, Type *ITy) {
1421-
Type *FScalarTy = FTy->getScalarType();
1422-
Type *IScalarTy = ITy->getScalarType();
1423-
1424-
// Do we have enough bits in the significand to represent the result of
1425-
// the integer addition?
1426-
unsigned MaxRepresentableBits =
1427-
APFloat::semanticsPrecision(FScalarTy->getFltSemantics());
1428-
return IScalarTy->getIntegerBitWidth() <= MaxRepresentableBits;
1429-
};
1410+
Value *IntOps[2];
1411+
Constant *Op1FpC = nullptr;
1412+
1413+
// Check for:
1414+
// 1) (binop ({s|u}itofp x), ({s|u}itofp y))
1415+
// 2) (binop ({s|u}itofp x), FpC)
1416+
if (!match(BO.getOperand(0), m_SIToFP(m_Value(IntOps[0]))) &&
1417+
!match(BO.getOperand(0), m_UIToFP(m_Value(IntOps[0]))))
1418+
return nullptr;
14301419

1431-
// (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
1432-
// ... if the constant fits in the integer value. This is useful for things
1433-
// like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer
1434-
// requires a constant pool load, and generally allows the add to be better
1435-
// instcombined.
1436-
if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
1437-
if (IsValidPromotion(FPType, LHSIntVal->getType())) {
1438-
Constant *CI = ConstantFoldCastOperand(Instruction::FPToSI, CFP,
1439-
LHSIntVal->getType(), DL);
1440-
if (LHSConv->hasOneUse() &&
1441-
ConstantFoldCastOperand(Instruction::SIToFP, CI, BO.getType(),
1442-
DL) == CFP &&
1443-
willNotOverflowSignedAdd(LHSIntVal, CI, BO)) {
1444-
// Insert the new integer add.
1445-
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI);
1446-
return new SIToFPInst(NewAdd, BO.getType());
1447-
}
1448-
}
1420+
if (!match(BO.getOperand(1), m_Constant(Op1FpC)) &&
1421+
!match(BO.getOperand(1), m_SIToFP(m_Value(IntOps[1]))) &&
1422+
!match(BO.getOperand(1), m_UIToFP(m_Value(IntOps[1]))))
1423+
return nullptr;
14491424

1450-
// (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
1451-
if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
1452-
Value *RHSIntVal = RHSConv->getOperand(0);
1453-
// It's enough to check LHS types only because we require int types to
1454-
// be the same for this transform.
1455-
if (IsValidPromotion(FPType, LHSIntVal->getType())) {
1456-
// Only do this if x/y have the same type, if at least one of them has a
1457-
// single use (so we don't increase the number of int->fp conversions),
1458-
// and if the integer add will not overflow.
1459-
if (LHSIntVal->getType() == RHSIntVal->getType() &&
1460-
(LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
1461-
willNotOverflowSignedAdd(LHSIntVal, RHSIntVal, BO)) {
1462-
// Insert the new integer add.
1463-
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, RHSIntVal);
1464-
return new SIToFPInst(NewAdd, BO.getType());
1465-
}
1466-
}
1425+
1426+
Type *FPTy = BO.getType();
1427+
Type *IntTy = IntOps[0]->getType();
1428+
1429+
// Do we have signed casts?
1430+
bool OpsFromSigned = isa<SIToFPInst>(BO.getOperand(0));
1431+
1432+
1433+
unsigned IntSz = IntTy->getScalarSizeInBits();
1434+
// This is the maximum number of in-use bits in the integer where the int -> fp
1435+
// casts are exact.
1436+
unsigned MaxRepresentableBits =
1437+
APFloat::semanticsPrecision(FPTy->getScalarType()->getFltSemantics());
1438+
1439+
// Cache KnownBits a bit to potentially save some analysis.
1440+
std::optional<KnownBits> OpsKnown[2];
1441+
1442+
// Preserve the known number of leading bits. This can allow us to trivially pass nsw/nuw
1443+
// checks later on.
1444+
unsigned NumUsedLeadingBits[2] = {IntSz, IntSz};
1445+
1446+
auto IsNonZero = [&](unsigned OpNo) -> bool {
1447+
if (OpsKnown[OpNo].has_value() && OpsKnown[OpNo]->isNonZero())
1448+
return true;
1449+
return isKnownNonZero(IntOps[OpNo], SQ.DL);
1450+
};
1451+
1452+
auto IsNonNeg = [&](unsigned OpNo) -> bool {
1453+
if (OpsKnown[OpNo].has_value() && OpsKnown[OpNo]->isNonNegative())
1454+
return true;
1455+
return isKnownNonNegative(IntOps[OpNo], SQ);
1456+
};
1457+
1458+
// Check if we know for certain that ({s|u}itofp op) is exact.
1459+
auto IsValidPromotion = [&](unsigned OpNo) -> bool {
1460+
// If fp precision >= bitwidth(op), then it's exact.
1461+
if (MaxRepresentableBits >= IntSz)
1462+
;
1463+
// Otherwise, if it's a signed cast, check that fp precision >= bitwidth(op) -
1464+
// numSignBits(op).
1465+
else if (OpsFromSigned)
1466+
NumUsedLeadingBits[OpNo] = IntSz - ComputeNumSignBits(IntOps[OpNo]);
1467+
// Finally, for unsigned, check that fp precision >= bitwidth(op) -
1468+
// numLeadingZeros(op).
1469+
else {
1470+
if (!OpsKnown[OpNo].has_value())
1471+
OpsKnown[OpNo] = computeKnownBits(IntOps[OpNo], /*Depth*/ 0, &BO);
1472+
NumUsedLeadingBits[OpNo] = IntSz - OpsKnown[OpNo]->countMinLeadingZeros();
1473+
}
1474+
// NB: We could also check if op is known to be a power of 2 or zero (which
1475+
// will always be representable). It's unlikely, however, that if we are
1476+
// unable to bound op in any way we will be able to pass the overflow checks
1477+
// later on.
1478+
1479+
if (MaxRepresentableBits < NumUsedLeadingBits[OpNo])
1480+
return false;
1481+
// Signed + Mul also requires that op is non-zero to avoid -0 cases.
1482+
return (OpsFromSigned && BO.getOpcode() == Instruction::FMul)
1483+
? IsNonZero(OpNo)
1484+
: true;
1485+
1486+
};
1487+
1488+
// If we have a constant rhs, see if we can losslessly convert it to an int.
1489+
if (Op1FpC != nullptr) {
1490+
Constant *Op1IntC = ConstantFoldCastOperand(
1491+
OpsFromSigned ? Instruction::FPToSI : Instruction::FPToUI, Op1FpC,
1492+
IntTy, DL);
1493+
if (Op1IntC == nullptr)
1494+
return nullptr;
1495+
if (ConstantFoldCastOperand(OpsFromSigned ? Instruction::SIToFP
1496+
: Instruction::UIToFP,
1497+
Op1IntC, FPTy, DL) != Op1FpC)
1498+
return nullptr;
1499+
1500+
// First try to keep sign of cast the same.
1501+
IntOps[1] = Op1IntC;
1502+
}
1503+
1504+
// Ensure lhs/rhs integer types match.
1505+
if (IntTy != IntOps[1]->getType())
1506+
return nullptr;
1507+
1508+
1509+
if (Op1FpC == nullptr) {
1510+
if (OpsFromSigned != isa<SIToFPInst>(BO.getOperand(1))) {
1511+
// If we have a signed + unsigned, see if we can treat both as signed
1512+
// (uitofp nneg x) == (sitofp nneg x).
1513+
if (OpsFromSigned ? !IsNonNeg(1) : !IsNonNeg(0))
1514+
return nullptr;
1515+
OpsFromSigned = true;
14671516
}
1517+
if (!IsValidPromotion(1))
1518+
return nullptr;
14681519
}
1469-
return nullptr;
1520+
if (!IsValidPromotion(0))
1521+
return nullptr;
1522+
1523+
// Finally, we check that the integer version of the binop will not overflow.
1524+
BinaryOperator::BinaryOps IntOpc;
1525+
// Because of the precision check, we can often rule out overflows.
1526+
bool NeedsOverflowCheck = true;
1527+
// Try to conservatively rule out overflow based on the already done precision
1528+
// checks.
1529+
unsigned OverflowMaxOutputBits = OpsFromSigned ? 2 : 1;
1530+
unsigned OverflowMaxCurBits =
1531+
std::max(NumUsedLeadingBits[0], NumUsedLeadingBits[1]);
1532+
bool OutputSigned = OpsFromSigned;
1533+
switch (BO.getOpcode()) {
1534+
case Instruction::FAdd:
1535+
IntOpc = Instruction::Add;
1536+
OverflowMaxOutputBits += OverflowMaxCurBits;
1537+
break;
1538+
case Instruction::FSub:
1539+
IntOpc = Instruction::Sub;
1540+
OverflowMaxOutputBits += OverflowMaxCurBits;
1541+
break;
1542+
case Instruction::FMul:
1543+
IntOpc = Instruction::Mul;
1544+
OverflowMaxOutputBits += OverflowMaxCurBits * 2;
1545+
break;
1546+
default:
1547+
llvm_unreachable("Unsupported binop");
1548+
}
1549+
// The precision check may have already ruled out overflow.
1550+
if (OverflowMaxOutputBits < IntSz) {
1551+
NeedsOverflowCheck = false;
1552+
// We can bound the unsigned overflow from sub to an in-range signed value (this is
1553+
// what allows us to avoid the overflow check for sub).
1554+
if (IntOpc == Instruction::Sub)
1555+
OutputSigned = true;
1556+
}
1557+
1558+
// Precision check did not rule out overflow, so need to check.
1559+
if (NeedsOverflowCheck &&
1560+
!willNotOverflow(IntOpc, IntOps[0], IntOps[1], BO, OutputSigned))
1561+
return nullptr;
1562+
1563+
Value *IntBinOp = Builder.CreateBinOp(IntOpc, IntOps[0], IntOps[1]);
1564+
if (OutputSigned)
1565+
return new SIToFPInst(IntBinOp, FPTy);
1566+
return new UIToFPInst(IntBinOp, FPTy);
14701567
}
14711568

14721569
/// A binop with a constant operand and a sign-extended boolean operand may be

llvm/test/Transforms/InstCombine/add-sitofp.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ define float @test_3(i32 %a, i32 %b) {
9090
; CHECK-LABEL: @test_3(
9191
; CHECK-NEXT: [[M:%.*]] = lshr i32 [[A:%.*]], 24
9292
; CHECK-NEXT: [[N:%.*]] = and i32 [[M]], [[B:%.*]]
93-
; CHECK-NEXT: [[O:%.*]] = sitofp i32 [[N]] to float
94-
; CHECK-NEXT: [[P:%.*]] = fadd float [[O]], 1.000000e+00
93+
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i32 [[N]], 1
94+
; CHECK-NEXT: [[P:%.*]] = sitofp i32 [[TMP1]] to float
9595
; CHECK-NEXT: ret float [[P]]
9696
;
9797
%m = lshr i32 %a, 24

0 commit comments

Comments
 (0)