-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[KnownBits] Make {s,u}{add,sub}_sat
optimal
#113096
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[KnownBits] Make {s,u}{add,sub}_sat
optimal
#113096
Conversation
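@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-support
Author: None (goldsteinn)
Changes: 1) Make signed-overflow detection optimal. 2) For signed overflow, try to rule out a direction even if we can't totally rule out overflow. 3) Intersect add/sub assuming no overflow with possible overflow clamping values, as opposed to add/sub without the assumption.
Full diff: https://github.com/llvm/llvm-project/pull/113096.diff
3 Files Affected: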
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..22a1628b0fa23a 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -610,28 +610,78 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
const KnownBits &RHS) {
// We don't see NSW even for sadd/ssub as we want to check if the result has
// signed overflow.
- KnownBits Res =
- KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
- unsigned BitWidth = Res.getBitWidth();
- auto SignBitKnown = [&](const KnownBits &K) {
- return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
- };
- std::optional<bool> Overflow;
+ unsigned BitWidth = LHS.getBitWidth();
+ std::optional<bool> Overflow;
+ // Even if we can't entirely rule out overflow, we may be able to rule out
+ // overflow in one direction. This allows us to potentially keep some of the
+ // add/sub bits. I.e if we can't overflow in the positive direction we won't
+ // clamp to INT_MAX so we can keep low 0s from the add/sub result.
+ bool MayNegClamp = true;
+ bool MayPosClamp = true;
if (Signed) {
- // If we can actually detect overflow do so. Otherwise leave Overflow as
- // nullopt (we assume it may have happened).
- if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
+ // Easy cases we can rule out any overflow.
+ if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
+ (LHS.isNonNegative() && RHS.isNegative())))
+ Overflow = false;
+ else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
+ (LHS.isNonNegative() && RHS.isNonNegative()))))
+ Overflow = false;
+ else {
+ // Check if we may overflow. If we can't rule out overflow then check if
+ // we can rule out a direction at least.
+ KnownBits UnsignedLHS = LHS;
+ KnownBits UnsignedRHS = RHS;
+ UnsignedLHS.One.clearSignBit();
+ UnsignedLHS.Zero.setSignBit();
+ UnsignedRHS.One.clearSignBit();
+ UnsignedRHS.Zero.setSignBit();
+ KnownBits Res =
+ KnownBits::computeForAddSub(Add, /*NSW=*/false,
+ /*NUW=*/false, UnsignedLHS, UnsignedRHS);
if (Add) {
- // sadd.sat
- Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
- Res.isNonNegative() != LHS.isNonNegative());
+ if (Res.isNegative()) {
+ // Only overflow scenario is Pos + Pos.
+ MayNegClamp = false;
+ // Pos + Pos will overflow with extra signbit.
+ if (LHS.isNonNegative() && RHS.isNonNegative())
+ Overflow = true;
+ } else if (Res.isNonNegative()) {
+ // Only overflow scenario is Neg + Neg
+ MayPosClamp = false;
+ // Neg + Neg will overflow without extra signbit.
+ if (LHS.isNegative() && RHS.isNegative())
+ Overflow = true;
+ }
+ // We will never clamp to the opposite sign of N-bit result.
+ if (LHS.isNegative() || RHS.isNegative())
+ MayPosClamp = false;
+ if (LHS.isNonNegative() || RHS.isNonNegative())
+ MayNegClamp = false;
} else {
- // ssub.sat
- Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
- Res.isNonNegative() != LHS.isNonNegative());
+ if (Res.isNegative()) {
+ // Only overflow scenario is Neg - Pos.
+ MayPosClamp = false;
+ // Neg - Pos will overflow with extra signbit.
+ if (LHS.isNegative() && RHS.isNonNegative())
+ Overflow = true;
+ } else if (Res.isNonNegative()) {
+ // Only overflow scenario is Pos - Neg.
+ MayNegClamp = false;
+ // Pos - Neg will overflow without extra signbit.
+ if (LHS.isNonNegative() && RHS.isNegative())
+ Overflow = true;
+ }
+ // We will never clamp to the opposite sign of N-bit result.
+ if (LHS.isNegative() || RHS.isNonNegative())
+ MayPosClamp = false;
+ if (LHS.isNonNegative() || RHS.isNegative())
+ MayNegClamp = false;
}
}
+ // If we have ruled out all clamping, we will never overflow.
+ if (!MayNegClamp && !MayPosClamp)
+ Overflow = false;
} else if (Add) {
// uadd.sat
bool Of;
@@ -656,52 +706,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
}
}
- if (Signed) {
- if (Add) {
- if (LHS.isNonNegative() && RHS.isNonNegative()) {
- // Pos + Pos -> Pos
- Res.One.clearSignBit();
- Res.Zero.setSignBit();
- }
- if (LHS.isNegative() && RHS.isNegative()) {
- // Neg + Neg -> Neg
- Res.One.setSignBit();
- Res.Zero.clearSignBit();
- }
- } else {
- if (LHS.isNegative() && RHS.isNonNegative()) {
- // Neg - Pos -> Neg
- Res.One.setSignBit();
- Res.Zero.clearSignBit();
- } else if (LHS.isNonNegative() && RHS.isNegative()) {
- // Pos - Neg -> Pos
- Res.One.clearSignBit();
- Res.Zero.setSignBit();
- }
- }
- } else {
- // Add: Leading ones of either operand are preserved.
- // Sub: Leading zeros of LHS and leading ones of RHS are preserved
- // as leading zeros in the result.
- unsigned LeadingKnown;
- if (Add)
- LeadingKnown =
- std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
- else
- LeadingKnown =
- std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
-
- // We select between the operation result and all-ones/zero
- // respectively, so we can preserve known ones/zeros.
- APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
- if (Add) {
- Res.One |= Mask;
- Res.Zero &= ~Mask;
- } else {
- Res.Zero |= Mask;
- Res.One &= ~Mask;
- }
- }
+ KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
+ /*NUW=*/!Signed, LHS, RHS);
if (Overflow) {
// We know whether or not we overflowed.
@@ -714,7 +720,7 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
APInt C;
if (Signed) {
// sadd.sat / ssub.sat
- assert(SignBitKnown(LHS) &&
+ assert(!LHS.isSignUnknown() &&
"We somehow know overflow without knowing input sign");
C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
: APInt::getSignedMaxValue(BitWidth);
@@ -735,8 +741,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
if (Signed) {
// sadd.sat/ssub.sat
// We can keep our information about the sign bits.
- Res.Zero.clearLowBits(BitWidth - 1);
- Res.One.clearLowBits(BitWidth - 1);
+ if (MayPosClamp)
+ Res.Zero.clearLowBits(BitWidth - 1);
+ if (MayNegClamp)
+ Res.One.clearLowBits(BitWidth - 1);
} else if (Add) {
// uadd.sat
// We need to clear all the known zeros as we can only use the leading ones.
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
index c2926eaffa58c5..f9618e1ddbc022 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
@@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {
define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
; CHECK-LABEL: @ssub_sat_fail_may_overflow(
-; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], 15
-; CHECK-NEXT: [[YY:%.*]] = and i8 [[Y:%.*]], 15
-; CHECK-NEXT: [[LHS:%.*]] = or i8 [[XX]], 1
-; CHECK-NEXT: [[RHS:%.*]] = and i8 [[YY]], -2
-; CHECK-NEXT: [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[EXP]], 1
-; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT: ret i1 [[R]]
+; CHECK-NEXT: ret i1 false
;
%xx = and i8 %x, 15
%yy = and i8 %y, 15
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 551c1a8107494b..85b068e5725ce4 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -383,26 +383,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
"sadd_sat", KnownBits::sadd_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.sadd_sat(N2);
- },
- /*CheckOptimality=*/false);
+ });
testBinaryOpExhaustive(
"uadd_sat", KnownBits::uadd_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.uadd_sat(N2);
- },
- /*CheckOptimality=*/false);
+ });
testBinaryOpExhaustive(
"ssub_sat", KnownBits::ssub_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.ssub_sat(N2);
- },
- /*CheckOptimality=*/false);
+ });
testBinaryOpExhaustive(
"usub_sat", KnownBits::usub_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.usub_sat(N2);
- },
- /*CheckOptimality=*/false);
+ });
testBinaryOpExhaustive(
"shl",
[](const KnownBits &Known1, const KnownBits &Known2) {
|
ping |
please can you rebase to get the CI to run again? |
Changes are: 1) Make signed-overflow detection optimal. 2) For signed overflow, try to rule out one direction of overflow even if we can't rule out overflow entirely. 3) Intersect the add/sub result computed under the no-overflow assumption with the possible overflow clamping values, rather than using the add/sub result computed without that assumption.
9a4b687
to
a861421
Compare
Rebased |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jayfoad any thoughts?
Only that if the exhaustive unit test passes then the logic must be OK, so no objections from me. I don't have time to try to understand the patch deeply now. |
Changes are: 1) Make signed-overflow detection optimal. 2) For signed overflow, try to rule out one direction of overflow even if we can't rule out overflow entirely. 3) Intersect the add/sub result computed under the no-overflow assumption with the possible overflow clamping values, rather than using the add/sub result computed without that assumption.
Changes are:
1) Make signed-overflow detection optimal
2) For signed-overflow, try to rule out direction even if we can't
totally rule out overflow.
3) Intersect add/sub assuming no overflow with possible overflow
clamping values as opposed to add/sub without the assumption.