Skip to content

[InstCombine] Add folds for (fp_binop ({s|u}itofp x), ({s|u}itofp y)) #82555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 6 additions & 57 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1867,64 +1867,10 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {

// Check for (fadd double (sitofp x), y), see if we can merge this into an
// integer add followed by a promotion.
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
Value *LHSIntVal = LHSConv->getOperand(0);
Type *FPType = LHSConv->getType();

// TODO: This check is overly conservative. In many cases known bits
// analysis can tell us that the result of the addition has less significant
// bits than the integer type can hold.
auto IsValidPromotion = [](Type *FTy, Type *ITy) {
Type *FScalarTy = FTy->getScalarType();
Type *IScalarTy = ITy->getScalarType();

// Do we have enough bits in the significand to represent the result of
// the integer addition?
unsigned MaxRepresentableBits =
APFloat::semanticsPrecision(FScalarTy->getFltSemantics());
return IScalarTy->getIntegerBitWidth() <= MaxRepresentableBits;
};

// (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
// ... if the constant fits in the integer value. This is useful for things
// like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer
// requires a constant pool load, and generally allows the add to be better
// instcombined.
if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
if (IsValidPromotion(FPType, LHSIntVal->getType())) {
Constant *CI = ConstantFoldCastOperand(Instruction::FPToSI, CFP,
LHSIntVal->getType(), DL);
if (LHSConv->hasOneUse() &&
ConstantFoldCastOperand(Instruction::SIToFP, CI, I.getType(), DL) ==
CFP &&
willNotOverflowSignedAdd(LHSIntVal, CI, I)) {
// Insert the new integer add.
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv");
return new SIToFPInst(NewAdd, I.getType());
}
}

// (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
Value *RHSIntVal = RHSConv->getOperand(0);
// It's enough to check LHS types only because we require int types to
// be the same for this transform.
if (IsValidPromotion(FPType, LHSIntVal->getType())) {
// Only do this if x/y have the same type, if at least one of them has a
// single use (so we don't increase the number of int->fp conversions),
// and if the integer add will not overflow.
if (LHSIntVal->getType() == RHSIntVal->getType() &&
(LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
willNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) {
// Insert the new integer add.
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, RHSIntVal, "addconv");
return new SIToFPInst(NewAdd, I.getType());
}
}
}
}
if (Instruction *R = foldFBinOpOfIntCasts(I))
return R;

Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
// Handle specials cases for FAdd with selects feeding the operation
if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS))
return replaceInstUsesWith(I, V);
Expand Down Expand Up @@ -2847,6 +2793,9 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
if (Instruction *X = foldFNegIntoConstant(I, DL))
return X;

if (Instruction *R = foldFBinOpOfIntCasts(I))
return R;

Value *X, *Y;
Constant *C;

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN);
Instruction *foldBitcastExtElt(ExtractElementInst &ExtElt);
Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
Instruction *foldFBinOpOfIntCasts(BinaryOperator &I);
Instruction *foldBinopOfSextBoolToSelect(BinaryOperator &I);
Instruction *narrowBinOp(TruncInst &Trunc);
Instruction *narrowMaskedBinOp(BinaryOperator &And);
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
if (Instruction *R = foldFPSignBitOps(I))
return R;

if (Instruction *R = foldFBinOpOfIntCasts(I))
return R;

// X * -1.0 --> -X
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (match(Op1, m_SpecificFP(-1.0)))
Expand Down
173 changes: 173 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,179 @@ Value *InstCombinerImpl::dyn_castNegVal(Value *V) const {
return nullptr;
}

// Try to fold:
// 1) (fp_binop ({s|u}itofp x), ({s|u}itofp y))
// -> ({s|u}itofp (int_binop x, y))
// 2) (fp_binop ({s|u}itofp x), FpC)
// -> ({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))
Instruction *InstCombinerImpl::foldFBinOpOfIntCasts(BinaryOperator &BO) {
Value *IntOps[2] = {nullptr, nullptr};
Constant *Op1FpC = nullptr;

// Check for:
// 1) (binop ({s|u}itofp x), ({s|u}itofp y))
// 2) (binop ({s|u}itofp x), FpC)
if (!match(BO.getOperand(0), m_SIToFP(m_Value(IntOps[0]))) &&
!match(BO.getOperand(0), m_UIToFP(m_Value(IntOps[0]))))
return nullptr;

if (!match(BO.getOperand(1), m_Constant(Op1FpC)) &&
!match(BO.getOperand(1), m_SIToFP(m_Value(IntOps[1]))) &&
!match(BO.getOperand(1), m_UIToFP(m_Value(IntOps[1]))))
return nullptr;

Type *FPTy = BO.getType();
Type *IntTy = IntOps[0]->getType();

// Do we have signed casts?
bool OpsFromSigned = isa<SIToFPInst>(BO.getOperand(0));

unsigned IntSz = IntTy->getScalarSizeInBits();
// This is the maximum number of inuse bits by the integer where the int -> fp
// casts are exact.
unsigned MaxRepresentableBits =
APFloat::semanticsPrecision(FPTy->getScalarType()->getFltSemantics());
Comment on lines +1434 to +1435
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Off by 1 depending on signed or unsigned?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we are overly conservative if IntSz == MaxRepresentableBits + 1 and it's signed (note we can't expand the bound for, say, sitofp i16 to half because it won't sign extend). So this affects things like i12, i24, ...
I'll add a comment to that effect, although I don't think it's really worth increasing code complexity to handle a special case that will almost never actually apply (weird integer widths are rare).
Proofs:

; $> /home/noah/programs/opensource/llvm-dev/src/alive2/build/alive-tv (-smt-to=200000000)

----------------------------------------
define half @src_sisi_add_i12(i12 noundef %x, i12 noundef %y) {
#0:
  %overflow_info = sadd_overflow i12 noundef %x, noundef %y
  %does_overflow = extractvalue {i12, i1, i8} %overflow_info, 1
  %doesnot_overflow = xor i1 %does_overflow, 1
  assume i1 %doesnot_overflow
  %xf = sitofp i12 noundef %x to half
  %yf = sitofp i12 noundef %y to half
  %r = fadd half %xf, %yf
  ret half %r
}
=>
define half @tgt_sisi_add_i12(i12 noundef %x, i12 noundef %y) {
#0:
  %overflow_info = sadd_overflow i12 noundef %x, noundef %y
  %does_overflow = extractvalue {i12, i1, i8} %overflow_info, 1
  %doesnot_overflow = xor i1 %does_overflow, 1
  assume i1 %doesnot_overflow
  %xy = add i12 noundef %x, noundef %y
  %r = sitofp i12 %xy to half
  ret half %r
}
Transformation seems to be correct!


----------------------------------------
define half @src_uiui_add_i12(i12 noundef %x, i12 noundef %y) {
#0:
  %overflow_info = uadd_overflow i12 noundef %x, noundef %y
  %does_overflow = extractvalue {i12, i1, i8} %overflow_info, 1
  %doesnot_overflow = xor i1 %does_overflow, 1
  assume i1 %doesnot_overflow
  %xf = uitofp i12 noundef %x to half
  %yf = uitofp i12 noundef %y to half
  %r = fadd half %xf, %yf
  ret half %r
}
=>
define half @tgt_uiui_add_i12(i12 noundef %x, i12 noundef %y) {
#0:
  %overflow_info = uadd_overflow i12 noundef %x, noundef %y
  %does_overflow = extractvalue {i12, i1, i8} %overflow_info, 1
  %doesnot_overflow = xor i1 %does_overflow, 1
  assume i1 %doesnot_overflow
  %xy = add i12 noundef %x, noundef %y
  %r = uitofp i12 %xy to half
  ret half %r
}
Transformation doesn't verify!

ERROR: Value mismatch

Example:
i12 noundef %x = #x80d (2061, -2035)
i12 noundef %y = #x002 (2)

Source:
{i12, i1, i8} %overflow_info = { #x80f (2063, -2033), #x0 (0), poison }
i1 %does_overflow = #x0 (0)
i1 %doesnot_overflow = #x1 (1)
half %xf = #x6806 (2060)
half %yf = #x4000 (2)
half %r = #x6807 (2062)

Target:
{i12, i1, i8} %overflow_info = { #x80f (2063, -2033), #x0 (0), poison }
i1 %does_overflow = #x0 (0)
i1 %doesnot_overflow = #x1 (1)
i12 %xy = #x80f (2063, -2033)
half %r = #x6808 (2064)
Source value: #x6807 (2062)
Target value: #x6808 (2064)

Summary:
  1 correct transformations
  1 incorrect transformations
  0 failed-to-prove transformations
  0 Alive2 errors

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't add the comment here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added it at the check vs IntSz in isValidPromotion.


// Cache KnownBits a bit to potentially save some analysis.
WithCache<const Value *> OpsKnown[2] = {IntOps[0], IntOps[1]};

// Preserve known number of leading bits. This can allow us to trivial nsw/nuw
// checks later on.
unsigned NumUsedLeadingBits[2] = {IntSz, IntSz};

auto IsNonZero = [&](unsigned OpNo) -> bool {
if (OpsKnown[OpNo].hasKnownBits() &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (OpsKnown[OpNo].hasKnownBits() &&
if (

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to do this — we may unnecessarily compute known bits. The idea is that isKnownNonZero is what we want to spend our compute on; we only check known bits if we already have them, as an early out. The same applies to nonneg.

OpsKnown[OpNo].getKnownBits(SQ).isNonZero())
return true;
return isKnownNonZero(IntOps[OpNo], SQ.DL);
};

auto IsNonNeg = [&](unsigned OpNo) -> bool {
if (OpsKnown[OpNo].hasKnownBits() &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (OpsKnown[OpNo].hasKnownBits() &&
if (

OpsKnown[OpNo].getKnownBits(SQ).isNonNegative())
return true;
return isKnownNonNegative(IntOps[OpNo], SQ);
};

// Check if we know for certain that ({s|u}itofp op) is exact.
auto IsValidPromotion = [&](unsigned OpNo) -> bool {
// If fp precision >= bitwidth(op) then its exact.
// NB: This is slightly conservative for `sitofp`. For signed conversion, we
// can handle `MaxRepresentableBits == IntSz - 1` as the sign bit will be
// handled specially. We can't, however, increase the bound arbitrarily for
// `sitofp` as for larger sizes, it won't sign extend.
if (MaxRepresentableBits < IntSz) {
// Otherwise if its signed cast check that fp precisions >= bitwidth(op) -
// numSignBits(op).
// TODO: If we add support for `WithCache` in `ComputeNumSignBits`, change
// `IntOps[OpNo]` arguments to `KnownOps[OpNo]`.
if (OpsFromSigned)
NumUsedLeadingBits[OpNo] = IntSz - ComputeNumSignBits(IntOps[OpNo]);
// Finally for unsigned check that fp precision >= bitwidth(op) -
// numLeadingZeros(op).
else {
NumUsedLeadingBits[OpNo] =
IntSz - OpsKnown[OpNo].getKnownBits(SQ).countMinLeadingZeros();
}
}
// NB: We could also check if op is known to be a power of 2 or zero (which
// will always be representable). Its unlikely, however, that is we are
// unable to bound op in any way we will be able to pass the overflow checks
// later on.

if (MaxRepresentableBits < NumUsedLeadingBits[OpNo])
return false;
// Signed + Mul also requires that op is non-zero to avoid -0 cases.
return !OpsFromSigned || BO.getOpcode() != Instruction::FMul ||
IsNonZero(OpNo);
};

// If we have a constant rhs, see if we can losslessly convert it to an int.
if (Op1FpC != nullptr) {
Constant *Op1IntC = ConstantFoldCastOperand(
OpsFromSigned ? Instruction::FPToSI : Instruction::FPToUI, Op1FpC,
IntTy, DL);
if (Op1IntC == nullptr)
return nullptr;
if (ConstantFoldCastOperand(OpsFromSigned ? Instruction::SIToFP
: Instruction::UIToFP,
Op1IntC, FPTy, DL) != Op1FpC)
return nullptr;

// First try to keep sign of cast the same.
IntOps[1] = Op1IntC;
}

// Ensure lhs/rhs integer types match.
if (IntTy != IntOps[1]->getType())
return nullptr;

if (Op1FpC == nullptr) {
if (OpsFromSigned != isa<SIToFPInst>(BO.getOperand(1))) {
// If we have a signed + unsigned, see if we can treat both as signed
// (uitofp nneg x) == (sitofp nneg x).
if (OpsFromSigned ? !IsNonNeg(1) : !IsNonNeg(0))
return nullptr;
OpsFromSigned = true;
}
if (!IsValidPromotion(1))
return nullptr;
}
if (!IsValidPromotion(0))
return nullptr;

// Final we check if the integer version of the binop will not overflow.
BinaryOperator::BinaryOps IntOpc;
// Because of the precision check, we can often rule out overflows.
bool NeedsOverflowCheck = true;
// Try to conservatively rule out overflow based on the already done precision
// checks.
unsigned OverflowMaxOutputBits = OpsFromSigned ? 2 : 1;
unsigned OverflowMaxCurBits =
std::max(NumUsedLeadingBits[0], NumUsedLeadingBits[1]);
bool OutputSigned = OpsFromSigned;
switch (BO.getOpcode()) {
case Instruction::FAdd:
IntOpc = Instruction::Add;
OverflowMaxOutputBits += OverflowMaxCurBits;
break;
case Instruction::FSub:
IntOpc = Instruction::Sub;
OverflowMaxOutputBits += OverflowMaxCurBits;
break;
case Instruction::FMul:
IntOpc = Instruction::Mul;
OverflowMaxOutputBits += OverflowMaxCurBits * 2;
break;
default:
llvm_unreachable("Unsupported binop");
}
// The precision check may have already ruled out overflow.
if (OverflowMaxOutputBits < IntSz) {
NeedsOverflowCheck = false;
// We can bound unsigned overflow from sub to in range signed value (this is
// what allows us to avoid the overflow check for sub).
if (IntOpc == Instruction::Sub)
OutputSigned = true;
}

// Precision check did not rule out overflow, so need to check.
// TODO: If we add support for `WithCache` in `willNotOverflow`, change
// `IntOps[...]` arguments to `KnownOps[...]`.
if (NeedsOverflowCheck &&
!willNotOverflow(IntOpc, IntOps[0], IntOps[1], BO, OutputSigned))
return nullptr;

Value *IntBinOp = Builder.CreateBinOp(IntOpc, IntOps[0], IntOps[1]);
if (auto *IntBO = dyn_cast<BinaryOperator>(IntBinOp)) {
IntBO->setHasNoSignedWrap(OutputSigned);
IntBO->setHasNoUnsignedWrap(!OutputSigned);
}
if (OutputSigned)
return new SIToFPInst(IntBinOp, FPTy);
return new UIToFPInst(IntBinOp, FPTy);
}

/// A binop with a constant operand and a sign-extended boolean operand may be
/// converted into a select of constants by applying the binary operation to
/// the constant with the two possible values of the extended boolean (0 or -1).
Expand Down
24 changes: 11 additions & 13 deletions llvm/test/Transforms/InstCombine/add-sitofp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ define double @x(i32 %a, i32 %b) {
; CHECK-LABEL: @x(
; CHECK-NEXT: [[M:%.*]] = lshr i32 [[A:%.*]], 24
; CHECK-NEXT: [[N:%.*]] = and i32 [[M]], [[B:%.*]]
; CHECK-NEXT: [[ADDCONV:%.*]] = add nuw nsw i32 [[N]], 1
; CHECK-NEXT: [[P:%.*]] = sitofp i32 [[ADDCONV]] to double
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i32 [[N]], 1
; CHECK-NEXT: [[P:%.*]] = sitofp i32 [[TMP1]] to double
; CHECK-NEXT: ret double [[P]]
;
%m = lshr i32 %a, 24
Expand All @@ -19,8 +19,8 @@ define double @x(i32 %a, i32 %b) {
define double @test(i32 %a) {
; CHECK-LABEL: @test(
; CHECK-NEXT: [[A_AND:%.*]] = and i32 [[A:%.*]], 1073741823
; CHECK-NEXT: [[ADDCONV:%.*]] = add nuw nsw i32 [[A_AND]], 1
; CHECK-NEXT: [[RES:%.*]] = sitofp i32 [[ADDCONV]] to double
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i32 [[A_AND]], 1
; CHECK-NEXT: [[RES:%.*]] = sitofp i32 [[TMP1]] to double
; CHECK-NEXT: ret double [[RES]]
;
; Drop two highest bits to guarantee that %a + 1 doesn't overflow
Expand Down Expand Up @@ -48,8 +48,8 @@ define double @test_2(i32 %a, i32 %b) {
; CHECK-LABEL: @test_2(
; CHECK-NEXT: [[A_AND:%.*]] = and i32 [[A:%.*]], 1073741823
; CHECK-NEXT: [[B_AND:%.*]] = and i32 [[B:%.*]], 1073741823
; CHECK-NEXT: [[ADDCONV:%.*]] = add nuw nsw i32 [[A_AND]], [[B_AND]]
; CHECK-NEXT: [[RES:%.*]] = sitofp i32 [[ADDCONV]] to double
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i32 [[A_AND]], [[B_AND]]
; CHECK-NEXT: [[RES:%.*]] = sitofp i32 [[TMP1]] to double
; CHECK-NEXT: ret double [[RES]]
;
; Drop two highest bits to guarantee that %a + %b doesn't overflow
Expand Down Expand Up @@ -83,15 +83,13 @@ define float @test_2_neg(i32 %a, i32 %b) {
ret float %res
}

; This test demonstrates overly conservative legality check. The float addition
; can be replaced with the integer addition because the result of the operation
; can be represented in float, but we don't do that now.
; can be represented in float.
define float @test_3(i32 %a, i32 %b) {
; CHECK-LABEL: @test_3(
; CHECK-NEXT: [[M:%.*]] = lshr i32 [[A:%.*]], 24
; CHECK-NEXT: [[N:%.*]] = and i32 [[M]], [[B:%.*]]
; CHECK-NEXT: [[O:%.*]] = sitofp i32 [[N]] to float
; CHECK-NEXT: [[P:%.*]] = fadd float [[O]], 1.000000e+00
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i32 [[N]], 1
; CHECK-NEXT: [[P:%.*]] = sitofp i32 [[TMP1]] to float
; CHECK-NEXT: ret float [[P]]
;
%m = lshr i32 %a, 24
Expand All @@ -105,8 +103,8 @@ define <4 x double> @test_4(<4 x i32> %a, <4 x i32> %b) {
; CHECK-LABEL: @test_4(
; CHECK-NEXT: [[A_AND:%.*]] = and <4 x i32> [[A:%.*]], <i32 1073741823, i32 1073741823, i32 1073741823, i32 1073741823>
; CHECK-NEXT: [[B_AND:%.*]] = and <4 x i32> [[B:%.*]], <i32 1073741823, i32 1073741823, i32 1073741823, i32 1073741823>
; CHECK-NEXT: [[ADDCONV:%.*]] = add nuw nsw <4 x i32> [[A_AND]], [[B_AND]]
; CHECK-NEXT: [[RES:%.*]] = sitofp <4 x i32> [[ADDCONV]] to <4 x double>
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw <4 x i32> [[A_AND]], [[B_AND]]
; CHECK-NEXT: [[RES:%.*]] = sitofp <4 x i32> [[TMP1]] to <4 x double>
; CHECK-NEXT: ret <4 x double> [[RES]]
;
; Drop two highest bits to guarantee that %a + %b doesn't overflow
Expand Down
Loading