KnownBits: generalize high-bits of mul to overflows #114211

Closed
wants to merge 6 commits into from

artagnon
Contributor

Make the non-overflow case of KnownBits::mul optimal, and generalize it to the overflow case by relying on the min-product in addition to the max-product. Note that the overflow case cannot be optimal unless we also look at the bits in between the min-product and the max-product.
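
For readers who want the non-overflow idea in isolation, here is a minimal sketch using plain 8-bit integers instead of `APInt`. The names `Known8`, `HighBits`, and `mulHighBitsNoOverflow` are hypothetical and not part of LLVM; the sketch covers only unsigned operands whose max-product fits in the bit width, and it is not the code in this patch. It relies on the fact that every product lies between the min-product and the max-product, so the leading zeros of the max-product and the leading ones of the min-product hold for all products.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 8-bit stand-in for llvm::KnownBits: a set bit in Zero (One)
// means that bit of the value is known to be 0 (1).
struct Known8 {
  uint8_t Zero;
  uint8_t One;
};

// Number of high bits of the product known to be 0 and known to be 1.
struct HighBits {
  unsigned LeadZ;
  unsigned LeadO;
};

static unsigned countLeadingZeros8(uint8_t V) {
  unsigned N = 0;
  for (uint8_t Mask = 0x80; Mask != 0 && !(V & Mask); Mask >>= 1)
    ++N;
  return N;
}

static unsigned countLeadingOnes8(uint8_t V) {
  return countLeadingZeros8(static_cast<uint8_t>(~V));
}

HighBits mulHighBitsNoOverflow(Known8 L, Known8 R) {
  assert((L.Zero & L.One) == 0 && (R.Zero & R.One) == 0 &&
         "conflicting known bits");
  // Unsigned min: unknown bits cleared. Unsigned max: unknown bits set.
  unsigned MinL = L.One, MaxL = static_cast<uint8_t>(~L.Zero);
  unsigned MinR = R.One, MaxR = static_cast<uint8_t>(~R.Zero);
  // Products are formed in a wider type so overflow of 8 bits is visible.
  unsigned MaxProduct = MaxL * MaxR;
  unsigned MinProduct = MinL * MinR;
  if (MaxProduct > 0xFF)
    return {0, 0}; // Overflow: this sketch claims nothing.
  // Every product P satisfies MinProduct <= P <= MaxProduct.
  return {countLeadingZeros8(static_cast<uint8_t>(MaxProduct)),
          countLeadingOnes8(static_cast<uint8_t>(MinProduct))};
}
```

With `x = xx & 2` and `y = yy & 4`, the max-product is 8, so the top four bits of the 8-bit product are known zero while bit 3 stays unknown.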

-- 8< --
Based on #113051.

KnownBits::mul suffers from the deficiency that it doesn't account for
signed inputs. Fix it by refining known leading zeros when both inputs
are signed, and setting known leading ones when one of the inputs is
signed. The strategy we've used is to still use umul_ov, after adjusting
for signed inputs, and setting known leading ones from the negation of
the result, when it is known to be negative, noting that a possibly-zero
result is a special case.
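
As a rough illustration of the mixed-sign strategy described above, the following sketch works on 8-bit values with plain integers. It assumes the negative operand is given by the range of its magnitude and the non-negative operand by its value range; the function name and this interval-based interface are hypothetical and are not how KnownBits::mul is written. It multiplies the magnitudes, negates the result, and counts leading ones, treating a possibly-zero product as the special case in which no leading ones can be claimed.

```cpp
#include <cassert>
#include <cstdint>

static unsigned countLeadingOnes8(uint8_t V) {
  unsigned N = 0;
  for (uint8_t Mask = 0x80; Mask != 0 && (V & Mask); Mask >>= 1)
    ++N;
  return N;
}

// LHS is negative with magnitude in [MinMagL, MaxMagL], i.e. LHS lies in
// [-MaxMagL, -MinMagL]. RHS is non-negative and lies in [MinR, MaxR].
// Returns how many high bits of the 8-bit product are known to be one.
unsigned leadingOnesOfMixedSignProduct(unsigned MinMagL, unsigned MaxMagL,
                                       unsigned MinR, unsigned MaxR) {
  assert(1 <= MinMagL && MinMagL <= MaxMagL && MaxMagL <= 128);
  assert(MinR <= MaxR && MaxR <= 127);
  // If RHS may be zero, the product may be zero, which has no leading ones.
  if (MinR == 0)
    return 0;
  // Multiply the magnitudes in a wider type to detect wrapping.
  unsigned MaxMagProduct = MaxMagL * MaxR;
  if (MaxMagProduct > 128)
    return 0; // Would wrap in 8-bit signed arithmetic: claim nothing.
  // Negating the largest magnitude gives the most negative product. Every
  // other product is closer to -1 and has at least as many leading ones.
  uint8_t MostNegative = static_cast<uint8_t>(0u - MaxMagProduct);
  return countLeadingOnes8(MostNegative);
}
```

For example, `x` in {-2, -1} times `y` in {1, 5} gives products no smaller than -10 (0xF6), so four leading ones are known; this matches the operands used by the `mul_high_bits_know2` test in this PR.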
llvmbot added the llvm:support and llvm:analysis (includes value tracking, cost tables and constant folding) labels on Oct 30, 2024
@llvmbot
Member

llvmbot commented Oct 30, 2024

@llvm/pr-subscribers-llvm-support

Author: Ramkumar Ramachandra (artagnon)

Changes

Make the non-overflow case of KnownBits::mul optimal, and generalize it to the overflow case by relying on the min-product in addition to the max-product. Note that the overflow case cannot be optimal unless we also look at the bits in between the min-product and the max-product.

-- 8< --
Based on #113051.


Full diff: https://github.com/llvm/llvm-project/pull/114211.diff

3 Files Affected:

  • (modified) llvm/lib/Support/KnownBits.cpp (+88-13)
  • (added) llvm/test/Analysis/ValueTracking/knownbits-mul.ll (+143)
  • (modified) llvm/unittests/Support/KnownBitsTest.cpp (+148-1)
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..c2d7c776725088 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -796,19 +796,93 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
   assert((!NoUndefSelfMultiply || LHS == RHS) &&
          "Self multiplication knownbits mismatch");
 
-  // Compute the high known-0 bits by multiplying the unsigned max of each side.
-  // Conservatively, M active bits * N active bits results in M + N bits in the
-  // result. But if we know a value is a power-of-2 for example, then this
-  // computes one more leading zero.
-  // TODO: This could be generalized to number of sign bits (negative numbers).
-  APInt UMaxLHS = LHS.getMaxValue();
-  APInt UMaxRHS = RHS.getMaxValue();
-
-  // For leading zeros in the result to be valid, the unsigned max product must
-  // fit in the bitwidth (it must not overflow).
+  // Compute the high known-0 or known-1 bits by multiplying the min and max of
+  // each side.
+  APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
+        MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(),
+        MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
+        MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
+
+  // If MaxProduct doesn't overflow, then MinProduct doesn't overflow either.
+  // However, if MaxProduct does overflow, MinProduct may or may not
+  // overflow.
   bool HasOverflow;
-  APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
-  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
+  APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
+        MinProduct = MinLHS * MinRHS;
+
+  bool OpsSignMatch = LHS.isNegative() == RHS.isNegative();
+  if (!OpsSignMatch) {
+    // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
+    // negated to turn them into the corresponding signed-multiplication
+    // wrapped values.
+    MinProduct.negate();
+    MaxProduct.negate();
+  }
+
+  // Unless both MinProduct and MaxProduct have the same sign, there won't be
+  // any leading zeros or ones in the result.
+  unsigned LeadZ = 0, LeadO = 0;
+  if (MinProduct.isNegative() == MaxProduct.isNegative()) {
+    APInt LHSUnknown = (~LHS.Zero & ~LHS.One),
+          RHSUnknown = (~RHS.Zero & ~RHS.One);
+
+    // A product of M active bits * N active bits has at most M + N active
+    // bits. If either of the operands is a power of two, the result has one
+    // fewer active bit.
+    auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
+      if (A.isZero() || B.isZero())
+        return 0;
+      return A.getActiveBits() + B.getActiveBits() -
+             (A.isPowerOf2() || B.isPowerOf2());
+    };
+
+    // We want to compute the number of active bits in the difference between
+    // the non-wrapped max product and non-wrapped min product, but we want to
+    // avoid computing the non-wrapped max/min product.
+    unsigned ActiveBitsInDiff;
+    if (MinLHS.isZero() && MinRHS.isZero())
+      ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
+    else
+      ActiveBitsInDiff =
+          ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
+          ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
+
+    // Handle two cases uniformly: when MaxProduct does not overflow, the high
+    // zeros and ones are computed optimally; when it does overflow but the
+    // result shifts by at most BitWidth, the high zeros and ones are still
+    // computed, but not optimally.
+    if (!HasOverflow || ActiveBitsInDiff <= BitWidth) {
+      // Checks that A.ugt(B), excluding the degenerate case where A is all-ones
+      // and B is zero.
+      auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
+        return (!A.isAllOnes() || !B.isZero()) && A.ugt(B);
+      };
+
+      // If we're shifting by BitWidth, MaxProduct and MinProduct are swapped.
+      bool MinMaxSwap = ActiveBitsInDiff == BitWidth;
+      if (MinMaxSwap)
+        std::swap(MaxProduct, MinProduct);
+
+      if (OpsSignMatch != MinMaxSwap) {
+        // Normally, this is the case for when the signs of LHS and RHS match,
+        // and the else branch is for when the signs mismatch. However, if min
+        // and max were swapped, we need to invert these cases.
+        if (UgtCheckCorner(MaxProduct, MinProduct)) {
+          // Normally, when the signs of LHS and RHS match, we can safely set
+          // leading zeros of the result. However, if both MaxProduct and
+          // MinProduct are negative, we can also set the leading ones.
+          LeadZ = MaxProduct.countLeadingZeros();
+          LeadO = (MaxProduct & MinProduct).countLeadingOnes();
+        }
+      } else if (UgtCheckCorner(MinProduct, MaxProduct)) {
+        // Normally, when the signs of LHS and RHS mismatch, we can safely set
+        // leading ones of the result. However, if both MaxProduct and
+        // MinProduct are non-negative, we can also set the leading zeros.
+        LeadO = MaxProduct.countLeadingOnes();
+        LeadZ = (MaxProduct | MinProduct).countLeadingZeros();
+      }
+    }
+  }
 
   // The result of the bottom bits of an integer multiply can be
   // inferred by looking at the bottom bits of both operands and
@@ -873,8 +947,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
 
   KnownBits Res(BitWidth);
   Res.Zero.setHighBits(LeadZ);
+  Res.One.setHighBits(LeadO);
   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
-  Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+  Res.One |= BottomKnown.getLoBits(ResultBitsKnown);
 
   // If we're self-multiplying then bit[1] is guaranteed to be zero.
   if (NoUndefSelfMultiply && BitWidth > 1) {
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
new file mode 100644
index 00000000000000..37526c67f0d9e1
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @mul_low_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 2
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_partially_known(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_partially_known(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 2
+; CHECK-NEXT:    [[MUL:%.*]] = sub nsw i8 0, [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 2
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %x.notsmin = or i8 %x, 3
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x.notsmin, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_low_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], 4
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 6
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 6
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 6
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 -16
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %y.nonzero = or i8 %y, 1
+  %mul = mul i8 %x, %y.nonzero
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    ret i8 0
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = and i8 [[XX]], 2
+; CHECK-NEXT:    [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]]
+; CHECK-NEXT:    ret i8 [[MUL]]
+;
+  %x = and i8 %xx, 2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 8
+  ret i8 %r
+}
+
+define i8 @mul_high_bits_unknown2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], -2
+; CHECK-NEXT:    [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], -16
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -2
+  %y = and i8 %yy, 4
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, -16
+  ret i8 %r
+}
+
+; TODO: This can be reduced to zero.
+define i8 @mul_high_bits_unknown3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT:    [[X:%.*]] = or i8 [[XX]], 28
+; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY]], 30
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[MUL]], 16
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or i8 %xx, -4
+  %y = or i8 %yy, -2
+  %mul = mul i8 %x, %y
+  %r = and i8 %mul, 16
+  ret i8 %r
+}
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b16368de176481..2be2e1d093315c 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) {
   }
 }
 
-TEST(KnownBitsTest, MulExhaustive) {
+TEST(KnownBitsTest, MulLowBitsExhaustive) {
   for (unsigned Bits : {1, 4}) {
     ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
       ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
@@ -849,4 +849,151 @@ TEST(KnownBitsTest, MulExhaustive) {
   }
 }
 
+TEST(KnownBitsTest, MulHighBitsNoOverflow) {
+  for (unsigned Bits : {1, 4}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+        KnownBits Exact(Bits), WideExact(2 * Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        bool HasOverflow;
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            // The final value of HasOverflow corresponds to the multiplication
+            // in the last iteration, which is the max product.
+            APInt Res = N1.umul_ov(N2, HasOverflow);
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (!Exact.hasConflict() && !HasOverflow) {
+          // Check that leading zeros and leading ones are optimal in the
+          // result, provided there is no overflow.
+          APInt ZerosMask =
+                    APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+                OnesMask =
+                    APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+          KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+          KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+          ExactZeros.Zero.setAllBits();
+          ExactZeros.One.setAllBits();
+          ExactOnes.Zero.setAllBits();
+          ExactOnes.One.setAllBits();
+
+          ExactZeros.Zero = Exact.Zero & ZerosMask;
+          ExactZeros.One = Exact.One & ZerosMask;
+          ComputedZeros.Zero = Computed.Zero & ZerosMask;
+          ComputedZeros.One = Computed.One & ZerosMask;
+          EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros,
+                                  {Known1, Known2},
+                                  /*CheckOptimality=*/true));
+
+          ExactOnes.Zero = Exact.Zero & OnesMask;
+          ExactOnes.One = Exact.One & OnesMask;
+          ComputedOnes.Zero = Computed.Zero & OnesMask;
+          ComputedOnes.One = Computed.One & OnesMask;
+          EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes,
+                                  {Known1, Known2},
+                                  /*CheckOptimality=*/true));
+        }
+      });
+    });
+  }
+}
+
+TEST(KnownBitsTest, MulHighBitsOverflow) {
+  unsigned Bits = 4;
+  using KnownUnknownPair = std::pair<int, int>;
+  SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = {
+      {{2, 0}, {7, -1}},  // 001?, 0111
+      {{2, -1}, {10, 0}}, // 0010, 101?
+      {{9, 2}, {9, 1}},   // 1?01, 10?1
+      {{5, 1}, {3, 2}}};  // 01?1, 0?11
+  for (auto [P1, P2] : TestPairs) {
+    KnownBits Known1(Bits), Known2(Bits);
+    auto [K1, U1] = P1;
+    auto [K2, U2] = P2;
+    Known1 = KnownBits::makeConstant(APInt(Bits, K1));
+    Known2 = KnownBits::makeConstant(APInt(Bits, K2));
+    if (U1 > -1) {
+      Known1.Zero.setBitVal(U1, 0);
+      Known1.One.setBitVal(U1, 0);
+    }
+    if (U2 > -1) {
+      Known2.Zero.setBitVal(U2, 0);
+      Known2.One.setBitVal(U2, 0);
+    }
+    KnownBits Computed = KnownBits::mul(Known1, Known2);
+    KnownBits Exact(Bits);
+    Exact.Zero.setAllBits();
+    Exact.One.setAllBits();
+
+    ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+      ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+        APInt Res = N1 * N2;
+        Exact.One &= Res;
+        Exact.Zero &= ~Res;
+      });
+    });
+
+    // Check that the leading zeros or ones are optimal for the given examples,
+    // which overflow. It is certainly sub-optimal on other examples.
+    APInt ZerosMask =
+              APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+          OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+    KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+    KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+    ExactZeros.Zero.setAllBits();
+    ExactZeros.One.setAllBits();
+    ExactOnes.Zero.setAllBits();
+    ExactOnes.One.setAllBits();
+
+    ExactZeros.Zero = Exact.Zero & ZerosMask;
+    ExactZeros.One = Exact.One & ZerosMask;
+    ComputedZeros.Zero = Computed.Zero & ZerosMask;
+    ComputedZeros.One = Computed.One & ZerosMask;
+    EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+
+    ExactOnes.Zero = Exact.Zero & OnesMask;
+    ExactOnes.One = Exact.One & OnesMask;
+    ComputedOnes.Zero = Computed.Zero & OnesMask;
+    ComputedOnes.One = Computed.One & OnesMask;
+    EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+  }
+}
+
+TEST(KnownBitsTest, MulStress) {
+  // Stress test KnownBits::mul on 5 and 6 bits, checking that the result is
+  // correct, even if not optimal.
+  for (unsigned Bits : {5, 6}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            APInt Res = N1 * N2;
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (!Exact.hasConflict()) {
+          EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+                                  /*CheckOptimality=*/false));
+        }
+      });
+    });
+  }
+}
 } // end anonymous namespace
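
The unit tests above compare the computed result against an exhaustively enumerated "exact" result. A minimal standalone sketch of that oracle, assuming plain integers in place of `APInt` and hypothetical names (`Known`, `exactMulKnownBits`), might look like this. It starts with every bit marked both known-zero and known-one, then intersects over all wrapped products of values consistent with the operands, mirroring the `Exact.One &= Res; Exact.Zero &= ~Res;` pattern in the tests.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for llvm::KnownBits, restricted to small bit widths.
struct Known {
  uint32_t Zero; // bits known to be 0
  uint32_t One;  // bits known to be 1
};

Known exactMulKnownBits(Known L, Known R, unsigned Bits) {
  assert(Bits >= 1 && Bits <= 8 && "enumeration is only practical when small");
  const uint32_t Mask = (1u << Bits) - 1;
  // Start with all bits set in both masks. If no value pair is consistent
  // with the inputs, the result stays in this conflicting state.
  Known Exact{Mask, Mask};
  for (uint32_t A = 0; A <= Mask; ++A) {
    // Skip values that contradict the known bits of L.
    if ((A & L.Zero) || (~A & L.One & Mask))
      continue;
    for (uint32_t B = 0; B <= Mask; ++B) {
      // Skip values that contradict the known bits of R.
      if ((B & R.Zero) || (~B & R.One & Mask))
        continue;
      uint32_t Res = (A * B) & Mask; // wrapped product
      Exact.One &= Res;
      Exact.Zero &= ~Res & Mask;
    }
  }
  return Exact;
}
```

For the first pair in `MulHighBitsOverflow` (`001?` times `0111` on 4 bits), the wrapped products are 14 and 5, so only bit 2 is known, and it is known to be one.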

@llvmbot
Member

llvmbot commented Oct 30, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Ramkumar Ramachandra (artagnon)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b16368de176481..2be2e1d093315c 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) {
   }
 }
 
-TEST(KnownBitsTest, MulExhaustive) {
+TEST(KnownBitsTest, MulLowBitsExhaustive) {
   for (unsigned Bits : {1, 4}) {
     ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
       ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
@@ -849,4 +849,151 @@ TEST(KnownBitsTest, MulExhaustive) {
   }
 }
 
+TEST(KnownBitsTest, MulHighBitsNoOverflow) {
+  for (unsigned Bits : {1, 4}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+        KnownBits Exact(Bits), WideExact(2 * Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        bool HasOverflow;
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            // The final value of HasOverflow corresponds to the multiplication
+            // in the last iteration, which is the max product.
+            APInt Res = N1.umul_ov(N2, HasOverflow);
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (!Exact.hasConflict() && !HasOverflow) {
+          // Check that leading zeros and leading ones are optimal in the
+          // result, provided there is no overflow.
+          APInt ZerosMask =
+                    APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+                OnesMask =
+                    APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+          KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+          KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+          ExactZeros.Zero.setAllBits();
+          ExactZeros.One.setAllBits();
+          ExactOnes.Zero.setAllBits();
+          ExactOnes.One.setAllBits();
+
+          ExactZeros.Zero = Exact.Zero & ZerosMask;
+          ExactZeros.One = Exact.One & ZerosMask;
+          ComputedZeros.Zero = Computed.Zero & ZerosMask;
+          ComputedZeros.One = Computed.One & ZerosMask;
+          EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros,
+                                  {Known1, Known2},
+                                  /*CheckOptimality=*/true));
+
+          ExactOnes.Zero = Exact.Zero & OnesMask;
+          ExactOnes.One = Exact.One & OnesMask;
+          ComputedOnes.Zero = Computed.Zero & OnesMask;
+          ComputedOnes.One = Computed.One & OnesMask;
+          EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes,
+                                  {Known1, Known2},
+                                  /*CheckOptimality=*/true));
+        }
+      });
+    });
+  }
+}
+
+TEST(KnownBitsTest, MulHighBitsOverflow) {
+  unsigned Bits = 4;
+  using KnownUnknownPair = std::pair<int, int>;
+  SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = {
+      {{2, 0}, {7, -1}},  // 001?, 0111
+      {{2, -1}, {10, 0}}, // 0010, 101?
+      {{9, 2}, {9, 1}},   // 1?01, 10?1
+      {{5, 1}, {3, 2}}};  // 01?1, 0?11
+  for (auto [P1, P2] : TestPairs) {
+    KnownBits Known1(Bits), Known2(Bits);
+    auto [K1, U1] = P1;
+    auto [K2, U2] = P2;
+    Known1 = KnownBits::makeConstant(APInt(Bits, K1));
+    Known2 = KnownBits::makeConstant(APInt(Bits, K2));
+    if (U1 > -1) {
+      Known1.Zero.setBitVal(U1, 0);
+      Known1.One.setBitVal(U1, 0);
+    }
+    if (U2 > -1) {
+      Known2.Zero.setBitVal(U2, 0);
+      Known2.One.setBitVal(U2, 0);
+    }
+    KnownBits Computed = KnownBits::mul(Known1, Known2);
+    KnownBits Exact(Bits);
+    Exact.Zero.setAllBits();
+    Exact.One.setAllBits();
+
+    ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+      ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+        APInt Res = N1 * N2;
+        Exact.One &= Res;
+        Exact.Zero &= ~Res;
+      });
+    });
+
+    // Check that the leading zeros or ones are optimal for the given examples,
+    // which overflow. It is certainly sub-optimal on other examples.
+    APInt ZerosMask =
+              APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
+          OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+    KnownBits ExactZeros(Bits), ComputedZeros(Bits);
+    KnownBits ExactOnes(Bits), ComputedOnes(Bits);
+    ExactZeros.Zero.setAllBits();
+    ExactZeros.One.setAllBits();
+    ExactOnes.Zero.setAllBits();
+    ExactOnes.One.setAllBits();
+
+    ExactZeros.Zero = Exact.Zero & ZerosMask;
+    ExactZeros.One = Exact.One & ZerosMask;
+    ComputedZeros.Zero = Computed.Zero & ZerosMask;
+    ComputedZeros.One = Computed.One & ZerosMask;
+    EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+
+    ExactOnes.Zero = Exact.Zero & OnesMask;
+    ExactOnes.One = Exact.One & OnesMask;
+    ComputedOnes.Zero = Computed.Zero & OnesMask;
+    ComputedOnes.One = Computed.One & OnesMask;
+    EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+  }
+}
+
+TEST(KnownBitsTest, MulStress) {
+  // Stress test KnownBits::mul on 5 and 6 bits, checking that the result is
+  // correct, even if not optimal.
+  for (unsigned Bits : {5, 6}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            APInt Res = N1 * N2;
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (!Exact.hasConflict()) {
+          EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+                                  /*CheckOptimality=*/false));
+        }
+      });
+    });
+  }
+}
 } // end anonymous namespace

@artagnon artagnon requested a review from RKSimon October 30, 2024 11:44
@artagnon artagnon force-pushed the knownbits-mul-overflow branch from 4090b2c to eef289a Compare October 30, 2024 12:26
Make the non-overflow case of KnownBits::mul optimal, and smoothly
generalize it to the case when overflow occurs by relying on min-product
in addition to max-product, noting that it cannot possibly be optimal
unless we also look at the bits in between min-product and max-product.
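To illustrate the min-product/max-product idea for the non-overflow case only, here is a minimal sketch. It is not the code in this PR: `HighBits`, `commonHighBits`, and `mulHighBitsNoOverflow` are hypothetical names, and plain `uint64_t` masks stand in for `KnownBits` and `APInt`. It relies on the fact that every unsigned value between the min product and the max product shares the leading bits on which those two products agree.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the high-bit part of a KnownBits result.
struct HighBits {
  uint64_t Zero; // bits known to be 0
  uint64_t One;  // bits known to be 1
};

// Leading bits shared by Lo and Hi. Every x with Lo <= x <= Hi has the same
// leading bits, so they are known for the whole range.
HighBits commonHighBits(uint64_t Lo, uint64_t Hi, unsigned Bits) {
  assert(Bits >= 1 && Bits <= 32 && Lo <= Hi);
  const uint64_t Mask = (uint64_t(1) << Bits) - 1;
  assert(Hi <= Mask);
  const uint64_t Diff = Lo ^ Hi;
  // Shrink the prefix mask from the bottom until it covers no differing bit.
  uint64_t Prefix = Mask;
  while (Diff & Prefix)
    Prefix = (Prefix << 1) & Mask;
  return {~Lo & Prefix, Lo & Prefix};
}

// Known high bits of an unsigned multiply, assuming conflict-free inputs
// given as (Zero, One) masks. Returns "nothing known" if the max product
// overflows; the overflow case is deliberately left out of this sketch.
HighBits mulHighBitsNoOverflow(uint64_t Zero1, uint64_t One1, uint64_t Zero2,
                               uint64_t One2, unsigned Bits) {
  assert(Bits >= 1 && Bits <= 32);
  const uint64_t Mask = (uint64_t(1) << Bits) - 1;
  assert((Zero1 & One1) == 0 && (Zero2 & One2) == 0);
  // Smallest value: unknown bits are 0. Largest value: unknown bits are 1.
  const uint64_t MinProduct = One1 * One2;
  const uint64_t MaxProduct = (~Zero1 & Mask) * (~Zero2 & Mask);
  if (MaxProduct > Mask)
    return {0, 0};
  return commonHighBits(MinProduct, MaxProduct, Bits);
}
```

For example, with 4-bit inputs `001?` and `0100`, the products range from 8 to 12, so the top bit is known to be one. The first `MulHighBitsOverflow` pair (`001?` times `0111`) has a max product of 21, which overflows 4 bits, so this sketch reports nothing for it.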
@artagnon artagnon force-pushed the knownbits-mul-overflow branch from eef289a to 04f1601 Compare October 30, 2024 12:41
@artagnon artagnon marked this pull request as draft October 30, 2024 15:09
@artagnon artagnon marked this pull request as ready for review October 31, 2024 22:46
@artagnon artagnon force-pushed the knownbits-mul-overflow branch 2 times, most recently from b127b93 to dd73f6c Compare October 31, 2024 22:47