Skip to content

[KnownBits] Make {s,u}{add,sub}_sat optimal #113096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 5, 2024

Conversation

goldsteinn
Copy link
Contributor

Changes are:
1) Make signed-overflow detection optimal
2) For signed-overflow, try to rule out direction even if we can't
totally rule out overflow.
3) Intersect add/sub assuming no overflow with possible overflow
clamping values as opposed to add/sub without the assumption.

@llvmbot llvmbot added llvm:support llvm:analysis Includes value tracking, cost tables and constant folding labels Oct 20, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 20, 2024

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-support

Author: None (goldsteinn)

Changes

Changes are:
1) Make signed-overflow detection optimal
2) For signed-overflow, try to rule out direction even if we can't
totally rule out overflow.
3) Intersect add/sub assuming no overflow with possible overflow
clamping values as opposed to add/sub without the assumption.


Full diff: https://github.com/llvm/llvm-project/pull/113096.diff

3 Files Affected:

  • (modified) llvm/lib/Support/KnownBits.cpp (+73-65)
  • (modified) llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll (+1-8)
  • (modified) llvm/unittests/Support/KnownBitsTest.cpp (+4-8)
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..22a1628b0fa23a 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -610,28 +610,78 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                      const KnownBits &RHS) {
   // We don't see NSW even for sadd/ssub as we want to check if the result has
   // signed overflow.
-  KnownBits Res =
-      KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
-  unsigned BitWidth = Res.getBitWidth();
-  auto SignBitKnown = [&](const KnownBits &K) {
-    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
-  };
-  std::optional<bool> Overflow;
+  unsigned BitWidth = LHS.getBitWidth();
 
+  std::optional<bool> Overflow;
+  // Even if we can't entirely rule out overflow, we may be able to rule out
+  // overflow in one direction. This allows us to potentially keep some of the
+  // add/sub bits. I.e if we can't overflow in the positive direction we won't
+  // clamp to INT_MAX so we can keep low 0s from the add/sub result.
+  bool MayNegClamp = true;
+  bool MayPosClamp = true;
   if (Signed) {
-    // If we can actually detect overflow do so. Otherwise leave Overflow as
-    // nullopt (we assume it may have happened).
-    if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
+    // Easy cases we can rule out any overflow.
+    if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
+                (LHS.isNonNegative() && RHS.isNegative())))
+      Overflow = false;
+    else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
+                       (LHS.isNonNegative() && RHS.isNonNegative()))))
+      Overflow = false;
+    else {
+      // Check if we may overflow. If we can't rule out overflow then check if
+      // we can rule out a direction at least.
+      KnownBits UnsignedLHS = LHS;
+      KnownBits UnsignedRHS = RHS;
+      UnsignedLHS.One.clearSignBit();
+      UnsignedLHS.Zero.setSignBit();
+      UnsignedRHS.One.clearSignBit();
+      UnsignedRHS.Zero.setSignBit();
+      KnownBits Res =
+          KnownBits::computeForAddSub(Add, /*NSW=*/false,
+                                      /*NUW=*/false, UnsignedLHS, UnsignedRHS);
       if (Add) {
-        // sadd.sat
-        Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
-                    Res.isNonNegative() != LHS.isNonNegative());
+        if (Res.isNegative()) {
+          // Only overflow scenario is Pos + Pos.
+          MayNegClamp = false;
+          // Pos + Pos will overflow with extra signbit.
+          if (LHS.isNonNegative() && RHS.isNonNegative())
+            Overflow = true;
+        } else if (Res.isNonNegative()) {
+          // Only overflow scenario is Neg + Neg
+          MayPosClamp = false;
+          // Neg + Neg will overflow without extra signbit.
+          if (LHS.isNegative() && RHS.isNegative())
+            Overflow = true;
+        }
+        // We will never clamp to the opposite sign of N-bit result.
+        if (LHS.isNegative() || RHS.isNegative())
+          MayPosClamp = false;
+        if (LHS.isNonNegative() || RHS.isNonNegative())
+          MayNegClamp = false;
       } else {
-        // ssub.sat
-        Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
-                    Res.isNonNegative() != LHS.isNonNegative());
+        if (Res.isNegative()) {
+          // Only overflow scenario is Neg - Pos.
+          MayPosClamp = false;
+          // Neg - Pos will overflow with extra signbit.
+          if (LHS.isNegative() && RHS.isNonNegative())
+            Overflow = true;
+        } else if (Res.isNonNegative()) {
+          // Only overflow scenario is Pos - Neg.
+          MayNegClamp = false;
+          // Pos - Neg will overflow without extra signbit.
+          if (LHS.isNonNegative() && RHS.isNegative())
+            Overflow = true;
+        }
+        // We will never clamp to the opposite sign of N-bit result.
+        if (LHS.isNegative() || RHS.isNonNegative())
+          MayPosClamp = false;
+        if (LHS.isNonNegative() || RHS.isNegative())
+          MayNegClamp = false;
       }
     }
+    // If we have ruled out all clamping, we will never overflow.
+    if (!MayNegClamp && !MayPosClamp)
+      Overflow = false;
   } else if (Add) {
     // uadd.sat
     bool Of;
@@ -656,52 +706,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
     }
   }
 
-  if (Signed) {
-    if (Add) {
-      if (LHS.isNonNegative() && RHS.isNonNegative()) {
-        // Pos + Pos -> Pos
-        Res.One.clearSignBit();
-        Res.Zero.setSignBit();
-      }
-      if (LHS.isNegative() && RHS.isNegative()) {
-        // Neg + Neg -> Neg
-        Res.One.setSignBit();
-        Res.Zero.clearSignBit();
-      }
-    } else {
-      if (LHS.isNegative() && RHS.isNonNegative()) {
-        // Neg - Pos -> Neg
-        Res.One.setSignBit();
-        Res.Zero.clearSignBit();
-      } else if (LHS.isNonNegative() && RHS.isNegative()) {
-        // Pos - Neg -> Pos
-        Res.One.clearSignBit();
-        Res.Zero.setSignBit();
-      }
-    }
-  } else {
-    // Add: Leading ones of either operand are preserved.
-    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
-    // as leading zeros in the result.
-    unsigned LeadingKnown;
-    if (Add)
-      LeadingKnown =
-          std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
-    else
-      LeadingKnown =
-          std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
-
-    // We select between the operation result and all-ones/zero
-    // respectively, so we can preserve known ones/zeros.
-    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
-    if (Add) {
-      Res.One |= Mask;
-      Res.Zero &= ~Mask;
-    } else {
-      Res.Zero |= Mask;
-      Res.One &= ~Mask;
-    }
-  }
+  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
+                                              /*NUW=*/!Signed, LHS, RHS);
 
   if (Overflow) {
     // We know whether or not we overflowed.
@@ -714,7 +720,7 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
     APInt C;
     if (Signed) {
       // sadd.sat / ssub.sat
-      assert(SignBitKnown(LHS) &&
+      assert(!LHS.isSignUnknown() &&
              "We somehow know overflow without knowing input sign");
       C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
                            : APInt::getSignedMaxValue(BitWidth);
@@ -735,8 +741,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
   if (Signed) {
     // sadd.sat/ssub.sat
     // We can keep our information about the sign bits.
-    Res.Zero.clearLowBits(BitWidth - 1);
-    Res.One.clearLowBits(BitWidth - 1);
+    if (MayPosClamp)
+      Res.Zero.clearLowBits(BitWidth - 1);
+    if (MayNegClamp)
+      Res.One.clearLowBits(BitWidth - 1);
   } else if (Add) {
     // uadd.sat
     // We need to clear all the known zeros as we can only use the leading ones.
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
index c2926eaffa58c5..f9618e1ddbc022 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
@@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {
 
 define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
 ; CHECK-LABEL: @ssub_sat_fail_may_overflow(
-; CHECK-NEXT:    [[XX:%.*]] = and i8 [[X:%.*]], 15
-; CHECK-NEXT:    [[YY:%.*]] = and i8 [[Y:%.*]], 15
-; CHECK-NEXT:    [[LHS:%.*]] = or i8 [[XX]], 1
-; CHECK-NEXT:    [[RHS:%.*]] = and i8 [[YY]], -2
-; CHECK-NEXT:    [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[EXP]], 1
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 false
 ;
   %xx = and i8 %x, 15
   %yy = and i8 %y, 15
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 551c1a8107494b..85b068e5725ce4 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -383,26 +383,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       "sadd_sat", KnownBits::sadd_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.sadd_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "uadd_sat", KnownBits::uadd_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.uadd_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "ssub_sat", KnownBits::ssub_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.ssub_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "usub_sat", KnownBits::usub_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.usub_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "shl",
       [](const KnownBits &Known1, const KnownBits &Known2) {

@goldsteinn goldsteinn requested review from jayfoad and RKSimon October 20, 2024 17:46
@goldsteinn
Copy link
Contributor Author

ping

@RKSimon
Copy link
Collaborator

RKSimon commented Oct 30, 2024

please can you rebase to get the CI to run again?

Changes are:
    1) Make signed-overflow detection optimal
    2) For signed-overflow, try to rule out direction even if we can't
       totally rule out overflow.
    3) Intersect add/sub assuming no overflow with possible overflow
       clamping values as opposed to add/sub without the assumption.
@goldsteinn goldsteinn force-pushed the goldsteinn/knownbits-add-sat-wait branch from 9a4b687 to a861421 Compare October 30, 2024 17:50
@goldsteinn
Copy link
Contributor Author

Rebased

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jayfoad any thoughts?

@jayfoad
Copy link
Contributor

jayfoad commented Nov 4, 2024

jayfoad any thoughts?

Only that if the exhaustive unit test passes then the logic must be OK, so no objections from me. I don't have time to try to understand the patch deeply now.

@goldsteinn goldsteinn merged commit 877cb9a into llvm:main Nov 5, 2024
8 checks passed
PhilippRados pushed a commit to PhilippRados/llvm-project that referenced this pull request Nov 6, 2024
Changes are:
    1) Make signed-overflow detection optimal
    2) For signed-overflow, try to rule out direction even if we can't
       totally rule out overflow.
    3) Intersect add/sub assuming no overflow with possible overflow
       clamping values as opposed to add/sub without the assumption.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants