Skip to content

Commit 946ea4e

Browse files
committed
[InstCombine] Add folds for (fp_binop ({s|u}itofp x), ({s|u}itofp y))
The full fold is one of the following: 1) `(fp_binop ({s|u}itofp x), ({s|u}itofp y))` -> `({s|u}itofp (int_binop x, y))`; 2) `(fp_binop ({s|u}itofp x), FpC)` -> `({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))`. The following binops are supported: `fmul` -> `mul`, `fadd` -> `add`, `fsub` -> `sub`. Proofs: https://alive2.llvm.org/ce/z/zuacA8 The proofs time out, so they must be reproduced locally. Closes llvm#82555
1 parent 0f5849e commit 946ea4e

File tree

6 files changed

+231
-150
lines changed

6 files changed

+231
-150
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2793,6 +2793,9 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
27932793
if (Instruction *X = foldFNegIntoConstant(I, DL))
27942794
return X;
27952795

2796+
if (Instruction *R = foldFBinOpOfIntCasts(I))
2797+
return R;
2798+
27962799
Value *X, *Y;
27972800
Constant *C;
27982801

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
769769
if (Instruction *R = foldFPSignBitOps(I))
770770
return R;
771771

772+
if (Instruction *R = foldFBinOpOfIntCasts(I))
773+
return R;
774+
772775
// X * -1.0 --> -X
773776
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
774777
if (match(Op1, m_SpecificFP(-1.0)))

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 164 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,71 +1402,176 @@ Value *InstCombinerImpl::dyn_castNegVal(Value *V) const {
14021402
}
14031403

14041404
// Try to fold:
1405-
// 1) (add (sitofp x), (sitofp y))
1406-
// -> (sitofp (add x, y))
1407-
// 2) (add (sitofp x), FpC)
1408-
// -> (sitofp (add x, (fptosi FpC)))
1405+
// 1) (fp_binop ({s|u}itofp x), ({s|u}itofp y))
1406+
// -> ({s|u}itofp (int_binop x, y))
1407+
// 2) (fp_binop ({s|u}itofp x), FpC)
1408+
// -> ({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))
14091409
Instruction *InstCombinerImpl::foldFBinOpOfIntCasts(BinaryOperator &BO) {
1410-
// Check for (fadd double (sitofp x), y), see if we can merge this into an
1411-
// integer add followed by a promotion.
1412-
Value *LHS = BO.getOperand(0), *RHS = BO.getOperand(1);
1413-
if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
1414-
Value *LHSIntVal = LHSConv->getOperand(0);
1415-
Type *FPType = LHSConv->getType();
1416-
1417-
// TODO: This check is overly conservative. In many cases known bits
1418-
// analysis can tell us that the result of the addition has less significant
1419-
// bits than the integer type can hold.
1420-
auto IsValidPromotion = [](Type *FTy, Type *ITy) {
1421-
Type *FScalarTy = FTy->getScalarType();
1422-
Type *IScalarTy = ITy->getScalarType();
1423-
1424-
// Do we have enough bits in the significand to represent the result of
1425-
// the integer addition?
1426-
unsigned MaxRepresentableBits =
1427-
APFloat::semanticsPrecision(FScalarTy->getFltSemantics());
1428-
return IScalarTy->getIntegerBitWidth() <= MaxRepresentableBits;
1429-
};
1410+
Value *IntOps[2] = {nullptr, nullptr};
1411+
Constant *Op1FpC = nullptr;
1412+
1413+
// Check for:
1414+
// 1) (binop ({s|u}itofp x), ({s|u}itofp y))
1415+
// 2) (binop ({s|u}itofp x), FpC)
1416+
if (!match(BO.getOperand(0), m_SIToFP(m_Value(IntOps[0]))) &&
1417+
!match(BO.getOperand(0), m_UIToFP(m_Value(IntOps[0]))))
1418+
return nullptr;
14301419

1431-
// (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
1432-
// ... if the constant fits in the integer value. This is useful for things
1433-
// like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer
1434-
// requires a constant pool load, and generally allows the add to be better
1435-
// instcombined.
1436-
if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
1437-
if (IsValidPromotion(FPType, LHSIntVal->getType())) {
1438-
Constant *CI = ConstantFoldCastOperand(Instruction::FPToSI, CFP,
1439-
LHSIntVal->getType(), DL);
1440-
if (LHSConv->hasOneUse() &&
1441-
ConstantFoldCastOperand(Instruction::SIToFP, CI, BO.getType(),
1442-
DL) == CFP &&
1443-
willNotOverflowSignedAdd(LHSIntVal, CI, BO)) {
1444-
// Insert the new integer add.
1445-
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI);
1446-
return new SIToFPInst(NewAdd, BO.getType());
1447-
}
1448-
}
1420+
if (!match(BO.getOperand(1), m_Constant(Op1FpC)) &&
1421+
!match(BO.getOperand(1), m_SIToFP(m_Value(IntOps[1]))) &&
1422+
!match(BO.getOperand(1), m_UIToFP(m_Value(IntOps[1]))))
1423+
return nullptr;
14491424

1450-
// (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
1451-
if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
1452-
Value *RHSIntVal = RHSConv->getOperand(0);
1453-
// It's enough to check LHS types only because we require int types to
1454-
// be the same for this transform.
1455-
if (IsValidPromotion(FPType, LHSIntVal->getType())) {
1456-
// Only do this if x/y have the same type, if at least one of them has a
1457-
// single use (so we don't increase the number of int->fp conversions),
1458-
// and if the integer add will not overflow.
1459-
if (LHSIntVal->getType() == RHSIntVal->getType() &&
1460-
(LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
1461-
willNotOverflowSignedAdd(LHSIntVal, RHSIntVal, BO)) {
1462-
// Insert the new integer add.
1463-
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, RHSIntVal);
1464-
return new SIToFPInst(NewAdd, BO.getType());
1465-
}
1425+
Type *FPTy = BO.getType();
1426+
Type *IntTy = IntOps[0]->getType();
1427+
1428+
// Do we have signed casts?
1429+
bool OpsFromSigned = isa<SIToFPInst>(BO.getOperand(0));
1430+
1431+
unsigned IntSz = IntTy->getScalarSizeInBits();
1432+
// This is the maximum number of in-use bits of the integer for which the int -> fp
1433+
// casts are exact.
1434+
unsigned MaxRepresentableBits =
1435+
APFloat::semanticsPrecision(FPTy->getScalarType()->getFltSemantics());
1436+
1437+
// Cache KnownBits a bit to potentially save some analysis.
1438+
WithCache<const Value *> OpsKnown[2] = {IntOps[0], IntOps[1]};
1439+
1440+
// Preserve known number of leading bits. This can let us trivially pass nsw/nuw
1441+
// checks later on.
1442+
unsigned NumUsedLeadingBits[2] = {IntSz, IntSz};
1443+
1444+
auto IsNonZero = [&](unsigned OpNo) -> bool {
1445+
if (OpsKnown[OpNo].hasKnownBits() &&
1446+
OpsKnown[OpNo].getKnownBits(SQ).isNonZero())
1447+
return true;
1448+
return isKnownNonZero(IntOps[OpNo], SQ.DL);
1449+
};
1450+
1451+
auto IsNonNeg = [&](unsigned OpNo) -> bool {
1452+
if (OpsKnown[OpNo].hasKnownBits() &&
1453+
OpsKnown[OpNo].getKnownBits(SQ).isNonNegative())
1454+
return true;
1455+
return isKnownNonNegative(IntOps[OpNo], SQ);
1456+
};
1457+
1458+
// Check if we know for certain that ({s|u}itofp op) is exact.
1459+
auto IsValidPromotion = [&](unsigned OpNo) -> bool {
1460+
// If fp precision >= bitwidth(op) then it's exact.
1461+
// NB: This is slightly conservative for `sitofp`. For signed conversion, we
1462+
// can handle `MaxRepresentableBits == IntSz - 1` as the sign bit will be
1463+
// handled specially. We can't, however, increase the bound arbitrarily for
1464+
// `sitofp` as for larger sizes, it won't sign extend.
1465+
if (MaxRepresentableBits < IntSz) {
1466+
// Otherwise, if it's a signed cast, check that fp precision >= bitwidth(op) -
1467+
// numSignBits(op).
1468+
// TODO: If we add support for `WithCache` in `ComputeNumSignBits`, change
1469+
// `IntOps[OpNo]` arguments to `OpsKnown[OpNo]`.
1470+
if (OpsFromSigned)
1471+
NumUsedLeadingBits[OpNo] = IntSz - ComputeNumSignBits(IntOps[OpNo]);
1472+
// Finally for unsigned check that fp precision >= bitwidth(op) -
1473+
// numLeadingZeros(op).
1474+
else {
1475+
NumUsedLeadingBits[OpNo] =
1476+
IntSz - OpsKnown[OpNo].getKnownBits(SQ).countMinLeadingZeros();
14661477
}
14671478
}
1479+
// NB: We could also check if op is known to be a power of 2 or zero (which
1480+
// will always be representable). It's unlikely, however, that if we are
1481+
// unable to bound op in any way we will be able to pass the overflow checks
1482+
// later on.
1483+
1484+
if (MaxRepresentableBits < NumUsedLeadingBits[OpNo])
1485+
return false;
1486+
// Signed + Mul also requires that op is non-zero to avoid -0 cases.
1487+
return !OpsFromSigned || BO.getOpcode() != Instruction::FMul ||
1488+
IsNonZero(OpNo);
1489+
};
1490+
1491+
// If we have a constant rhs, see if we can losslessly convert it to an int.
1492+
if (Op1FpC != nullptr) {
1493+
Constant *Op1IntC = ConstantFoldCastOperand(
1494+
OpsFromSigned ? Instruction::FPToSI : Instruction::FPToUI, Op1FpC,
1495+
IntTy, DL);
1496+
if (Op1IntC == nullptr)
1497+
return nullptr;
1498+
if (ConstantFoldCastOperand(OpsFromSigned ? Instruction::SIToFP
1499+
: Instruction::UIToFP,
1500+
Op1IntC, FPTy, DL) != Op1FpC)
1501+
return nullptr;
1502+
1503+
// First try to keep sign of cast the same.
1504+
IntOps[1] = Op1IntC;
14681505
}
1469-
return nullptr;
1506+
1507+
// Ensure lhs/rhs integer types match.
1508+
if (IntTy != IntOps[1]->getType())
1509+
return nullptr;
1510+
1511+
if (Op1FpC == nullptr) {
1512+
if (OpsFromSigned != isa<SIToFPInst>(BO.getOperand(1))) {
1513+
// If we have a signed + unsigned, see if we can treat both as signed
1514+
// (uitofp nneg x) == (sitofp nneg x).
1515+
if (OpsFromSigned ? !IsNonNeg(1) : !IsNonNeg(0))
1516+
return nullptr;
1517+
OpsFromSigned = true;
1518+
}
1519+
if (!IsValidPromotion(1))
1520+
return nullptr;
1521+
}
1522+
if (!IsValidPromotion(0))
1523+
return nullptr;
1524+
1525+
// Finally, we check that the integer version of the binop will not overflow.
1526+
BinaryOperator::BinaryOps IntOpc;
1527+
// Because of the precision check, we can often rule out overflows.
1528+
bool NeedsOverflowCheck = true;
1529+
// Try to conservatively rule out overflow based on the already done precision
1530+
// checks.
1531+
unsigned OverflowMaxOutputBits = OpsFromSigned ? 2 : 1;
1532+
unsigned OverflowMaxCurBits =
1533+
std::max(NumUsedLeadingBits[0], NumUsedLeadingBits[1]);
1534+
bool OutputSigned = OpsFromSigned;
1535+
switch (BO.getOpcode()) {
1536+
case Instruction::FAdd:
1537+
IntOpc = Instruction::Add;
1538+
OverflowMaxOutputBits += OverflowMaxCurBits;
1539+
break;
1540+
case Instruction::FSub:
1541+
IntOpc = Instruction::Sub;
1542+
OverflowMaxOutputBits += OverflowMaxCurBits;
1543+
break;
1544+
case Instruction::FMul:
1545+
IntOpc = Instruction::Mul;
1546+
OverflowMaxOutputBits += OverflowMaxCurBits * 2;
1547+
break;
1548+
default:
1549+
llvm_unreachable("Unsupported binop");
1550+
}
1551+
// The precision check may have already ruled out overflow.
1552+
if (OverflowMaxOutputBits < IntSz) {
1553+
NeedsOverflowCheck = false;
1554+
// We can bound unsigned overflow from sub to an in-range signed value (this is
1555+
// what allows us to avoid the overflow check for sub).
1556+
if (IntOpc == Instruction::Sub)
1557+
OutputSigned = true;
1558+
}
1559+
1560+
// Precision check did not rule out overflow, so need to check.
1561+
// TODO: If we add support for `WithCache` in `willNotOverflow`, change
1562+
// `IntOps[...]` arguments to `OpsKnown[...]`.
1563+
if (NeedsOverflowCheck &&
1564+
!willNotOverflow(IntOpc, IntOps[0], IntOps[1], BO, OutputSigned))
1565+
return nullptr;
1566+
1567+
Value *IntBinOp = Builder.CreateBinOp(IntOpc, IntOps[0], IntOps[1]);
1568+
if (auto *IntBO = dyn_cast<BinaryOperator>(IntBinOp)) {
1569+
IntBO->setHasNoSignedWrap(OutputSigned);
1570+
IntBO->setHasNoUnsignedWrap(!OutputSigned);
1571+
}
1572+
if (OutputSigned)
1573+
return new SIToFPInst(IntBinOp, FPTy);
1574+
return new UIToFPInst(IntBinOp, FPTy);
14701575
}
14711576

14721577
/// A binop with a constant operand and a sign-extended boolean operand may be

llvm/test/Transforms/InstCombine/add-sitofp.ll

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,13 @@ define float @test_2_neg(i32 %a, i32 %b) {
8383
ret float %res
8484
}
8585

86-
; This test demonstrates overly conservative legality check. The float addition
87-
; can be replaced with the integer addition because the result of the operation
88-
; can be represented in float, but we don't do that now.
86+
; The float addition can be replaced with an integer addition because the
; result of the operation can be represented in float.
8987
define float @test_3(i32 %a, i32 %b) {
9088
; CHECK-LABEL: @test_3(
9189
; CHECK-NEXT: [[M:%.*]] = lshr i32 [[A:%.*]], 24
9290
; CHECK-NEXT: [[N:%.*]] = and i32 [[M]], [[B:%.*]]
93-
; CHECK-NEXT: [[O:%.*]] = sitofp i32 [[N]] to float
94-
; CHECK-NEXT: [[P:%.*]] = fadd float [[O]], 1.000000e+00
91+
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i32 [[N]], 1
92+
; CHECK-NEXT: [[P:%.*]] = sitofp i32 [[TMP1]] to float
9593
; CHECK-NEXT: ret float [[P]]
9694
;
9795
%m = lshr i32 %a, 24

0 commit comments

Comments
 (0)