Skip to content

[KnownBits] Make {s,u}{add,sub}_sat optimal #113096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 77 additions & 65 deletions llvm/lib/Support/KnownBits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,28 +610,82 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
const KnownBits &RHS) {
// We don't see NSW even for sadd/ssub as we want to check if the result has
// signed overflow.
KnownBits Res =
KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
unsigned BitWidth = Res.getBitWidth();
auto SignBitKnown = [&](const KnownBits &K) {
return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
};
std::optional<bool> Overflow;
unsigned BitWidth = LHS.getBitWidth();

std::optional<bool> Overflow;
// Even if we can't entirely rule out overflow, we may be able to rule out
// overflow in one direction. This allows us to potentially keep some of the
// add/sub bits. I.e. if we can't overflow in the positive direction we won't
// clamp to INT_MAX so we can keep low 0s from the add/sub result.
bool MayNegClamp = true;
bool MayPosClamp = true;
if (Signed) {
// If we can actually detect overflow do so. Otherwise leave Overflow as
// nullopt (we assume it may have happened).
if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
// Easy cases we can rule out any overflow.
if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
(LHS.isNonNegative() && RHS.isNegative())))
Overflow = false;
else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
(LHS.isNonNegative() && RHS.isNonNegative()))))
Overflow = false;
else {
// Check if we may overflow. If we can't rule out overflow then check if
// we can rule out a direction at least.
KnownBits UnsignedLHS = LHS;
KnownBits UnsignedRHS = RHS;
// Get version of LHS/RHS with cleared signbit. This allows us to detect
// how the addition/subtraction might overflow into the signbit. Then
// using the actual known signbits of LHS/RHS, we can figure out which
// overflows are/aren't possible.
UnsignedLHS.One.clearSignBit();
UnsignedLHS.Zero.setSignBit();
UnsignedRHS.One.clearSignBit();
UnsignedRHS.Zero.setSignBit();
KnownBits Res =
KnownBits::computeForAddSub(Add, /*NSW=*/false,
/*NUW=*/false, UnsignedLHS, UnsignedRHS);
if (Add) {
// sadd.sat
Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
Res.isNonNegative() != LHS.isNonNegative());
if (Res.isNegative()) {
// Only overflow scenario is Pos + Pos.
MayNegClamp = false;
// Pos + Pos will overflow with extra signbit.
if (LHS.isNonNegative() && RHS.isNonNegative())
Overflow = true;
} else if (Res.isNonNegative()) {
// Only overflow scenario is Neg + Neg
MayPosClamp = false;
// Neg + Neg will overflow without extra signbit.
if (LHS.isNegative() && RHS.isNegative())
Overflow = true;
}
// We will never clamp to the opposite sign of N-bit result.
if (LHS.isNegative() || RHS.isNegative())
MayPosClamp = false;
if (LHS.isNonNegative() || RHS.isNonNegative())
MayNegClamp = false;
} else {
// ssub.sat
Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
Res.isNonNegative() != LHS.isNonNegative());
if (Res.isNegative()) {
// Only overflow scenario is Neg - Pos.
MayPosClamp = false;
// Neg - Pos will overflow with extra signbit.
if (LHS.isNegative() && RHS.isNonNegative())
Overflow = true;
} else if (Res.isNonNegative()) {
// Only overflow scenario is Pos - Neg.
MayNegClamp = false;
// Pos - Neg will overflow without extra signbit.
if (LHS.isNonNegative() && RHS.isNegative())
Overflow = true;
}
// We will never clamp to the opposite sign of N-bit result.
if (LHS.isNegative() || RHS.isNonNegative())
MayPosClamp = false;
if (LHS.isNonNegative() || RHS.isNegative())
MayNegClamp = false;
}
}
// If we have ruled out all clamping, we will never overflow.
if (!MayNegClamp && !MayPosClamp)
Overflow = false;
} else if (Add) {
// uadd.sat
bool Of;
Expand All @@ -656,52 +710,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
}
}

if (Signed) {
if (Add) {
if (LHS.isNonNegative() && RHS.isNonNegative()) {
// Pos + Pos -> Pos
Res.One.clearSignBit();
Res.Zero.setSignBit();
}
if (LHS.isNegative() && RHS.isNegative()) {
// Neg + Neg -> Neg
Res.One.setSignBit();
Res.Zero.clearSignBit();
}
} else {
if (LHS.isNegative() && RHS.isNonNegative()) {
// Neg - Pos -> Neg
Res.One.setSignBit();
Res.Zero.clearSignBit();
} else if (LHS.isNonNegative() && RHS.isNegative()) {
// Pos - Neg -> Pos
Res.One.clearSignBit();
Res.Zero.setSignBit();
}
}
} else {
// Add: Leading ones of either operand are preserved.
// Sub: Leading zeros of LHS and leading ones of RHS are preserved
// as leading zeros in the result.
unsigned LeadingKnown;
if (Add)
LeadingKnown =
std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
else
LeadingKnown =
std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());

// We select between the operation result and all-ones/zero
// respectively, so we can preserve known ones/zeros.
APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
if (Add) {
Res.One |= Mask;
Res.Zero &= ~Mask;
} else {
Res.Zero |= Mask;
Res.One &= ~Mask;
}
}
KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
/*NUW=*/!Signed, LHS, RHS);

if (Overflow) {
// We know whether or not we overflowed.
Expand All @@ -714,7 +724,7 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
APInt C;
if (Signed) {
// sadd.sat / ssub.sat
assert(SignBitKnown(LHS) &&
assert(!LHS.isSignUnknown() &&
"We somehow know overflow without knowing input sign");
C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
: APInt::getSignedMaxValue(BitWidth);
Expand All @@ -735,8 +745,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
if (Signed) {
// sadd.sat/ssub.sat
// We can keep our information about the sign bits.
Res.Zero.clearLowBits(BitWidth - 1);
Res.One.clearLowBits(BitWidth - 1);
if (MayPosClamp)
Res.Zero.clearLowBits(BitWidth - 1);
if (MayNegClamp)
Res.One.clearLowBits(BitWidth - 1);
} else if (Add) {
// uadd.sat
// We need to clear all the known zeros as we can only use the leading ones.
Expand Down
9 changes: 1 addition & 8 deletions llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {

define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
; CHECK-LABEL: @ssub_sat_fail_may_overflow(
; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], 15
; CHECK-NEXT: [[YY:%.*]] = and i8 [[Y:%.*]], 15
; CHECK-NEXT: [[LHS:%.*]] = or i8 [[XX]], 1
; CHECK-NEXT: [[RHS:%.*]] = and i8 [[YY]], -2
; CHECK-NEXT: [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
; CHECK-NEXT: [[AND:%.*]] = and i8 [[EXP]], 1
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 0
; CHECK-NEXT: ret i1 [[R]]
; CHECK-NEXT: ret i1 false
;
%xx = and i8 %x, 15
%yy = and i8 %y, 15
Expand Down
12 changes: 4 additions & 8 deletions llvm/unittests/Support/KnownBitsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,26 +383,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
"sadd_sat", KnownBits::sadd_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.sadd_sat(N2);
},
/*CheckOptimality=*/false);
});
testBinaryOpExhaustive(
"uadd_sat", KnownBits::uadd_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.uadd_sat(N2);
},
/*CheckOptimality=*/false);
});
testBinaryOpExhaustive(
"ssub_sat", KnownBits::ssub_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.ssub_sat(N2);
},
/*CheckOptimality=*/false);
});
testBinaryOpExhaustive(
"usub_sat", KnownBits::usub_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.usub_sat(N2);
},
/*CheckOptimality=*/false);
});
testBinaryOpExhaustive(
"shl",
[](const KnownBits &Known1, const KnownBits &Known2) {
Expand Down
Loading