-
Notifications
You must be signed in to change notification settings - Fork 14.3k
KnownBits: generalize high-bits of mul to overflows #114211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
KnownBits::mul has a deficiency: it does not account for inputs that are known to be negative. Fix it by refining the known leading zeros when both inputs are negative, and by setting known leading ones when exactly one of the inputs is negative. The strategy is to keep using umul_ov after adjusting for negative inputs (by multiplying their absolute values), and, when the result is known to be negative, to derive the known leading ones from the negation of the product, noting that a possibly-zero result is a special case.
@llvm/pr-subscribers-llvm-support Author: Ramkumar Ramachandra (artagnon) Changes: Make the non-overflow case of KnownBits::mul optimal, and smoothly generalize it to the case when overflow occurs by relying on min-product in addition to max-product, noting that it cannot possibly be optimal unless we also look at the bits in between min-product and max-product. -- 8< -- Full diff: https://github.com/llvm/llvm-project/pull/114211.diff 3 Files Affected:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..c2d7c776725088 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -796,19 +796,93 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
assert((!NoUndefSelfMultiply || LHS == RHS) &&
"Self multiplication knownbits mismatch");
- // Compute the high known-0 bits by multiplying the unsigned max of each side.
- // Conservatively, M active bits * N active bits results in M + N bits in the
- // result. But if we know a value is a power-of-2 for example, then this
- // computes one more leading zero.
- // TODO: This could be generalized to number of sign bits (negative numbers).
- APInt UMaxLHS = LHS.getMaxValue();
- APInt UMaxRHS = RHS.getMaxValue();
-
- // For leading zeros in the result to be valid, the unsigned max product must
- // fit in the bitwidth (it must not overflow).
+ // Compute the high known-0 or known-1 bits by multiplying the min and max of
+ // each side.
+ APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
+ MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(),
+ MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
+ MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
+
+ // If MaxProduct doesn't overflow, it implies that MinProduct also won't
+ // overflow. However, if MaxProduct overflows, there is no guarantee on the
+ // MinProduct overflowing.
bool HasOverflow;
- APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
- unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
+ APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
+ MinProduct = MinLHS * MinRHS;
+
+ bool OpsSignMatch = LHS.isNegative() == RHS.isNegative();
+ if (!OpsSignMatch) {
+ // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
+ // negated to turn them into the corresponding signed-multiplication
+ // wrapped values.
+ MinProduct.negate();
+ MaxProduct.negate();
+ }
+
+ // Unless both MinProduct and MaxProduct are the same sign, there won't be any
+ // leading zeros or ones in the result.
+ unsigned LeadZ = 0, LeadO = 0;
+ if (MinProduct.isNegative() == MaxProduct.isNegative()) {
+ APInt LHSUnknown = (~LHS.Zero & ~LHS.One),
+ RHSUnknown = (~RHS.Zero & ~RHS.One);
+
+ // A product of M active bits * N active bits results in M + N bits in the
+ // result. If either of the operands is a power of two, the result has one
+ // less active bit.
+ auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
+ if (A.isZero() || B.isZero())
+ return 0;
+ return A.getActiveBits() + B.getActiveBits() -
+ (A.isPowerOf2() || B.isPowerOf2());
+ };
+
+ // We want to compute the number of active bits in the difference between
+ // the non-wrapped max product and non-wrapped min product, but we want to
+ // avoid computing the non-wrapped max/min product.
+ unsigned ActiveBitsInDiff;
+ if (MinLHS.isZero() && MinRHS.isZero())
+ ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
+ else
+ ActiveBitsInDiff =
+ ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
+ ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
+
+ // We uniformly handle the case where there is no max-overflow, in which
+ // case the high zeros and ones are computed optimally, and where there is,
+ // but the result shifts at most by BitWidth, in which case the high zeros
+ // and ones are not computed optimally.
+ if (!HasOverflow || ActiveBitsInDiff <= BitWidth) {
+ // Checks that A.ugt(B), excluding the degenerate case where A is all-ones
+ // and B is zero.
+ auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
+ return (!A.isAllOnes() || !B.isZero()) && A.ugt(B);
+ };
+
+ // If we're shifting by BitWidth, MaxProduct and MinProduct are swapped.
+ bool MinMaxSwap = ActiveBitsInDiff == BitWidth;
+ if (MinMaxSwap)
+ std::swap(MaxProduct, MinProduct);
+
+ if (OpsSignMatch != MinMaxSwap) {
+ // Normally, this is the case for when the signs of LHS and RHS match,
+ // and the else branch is for when the signs mismatch. However, if min
+ // and max were swapped, we need to invert these cases.
+ if (UgtCheckCorner(MaxProduct, MinProduct)) {
+ // Normally, when the signs of LHS and RHS match, we can safely set
+ // leading zeros of the result. However, if both MaxProduct and
+ // MinProduct are negative, we can also set the leading ones.
+ LeadZ = MaxProduct.countLeadingZeros();
+ LeadO = (MaxProduct & MinProduct).countLeadingOnes();
+ }
+ } else if (UgtCheckCorner(MinProduct, MaxProduct)) {
+ // Normally, when the signs of LHS and RHS mismatch, we can safely set
+ // leading ones of the result. However, if both MaxProduct and
+ // MinProduct are non-negative, we can also set the leading zeros.
+ LeadO = MaxProduct.countLeadingOnes();
+ LeadZ = (MaxProduct | MinProduct).countLeadingZeros();
+ }
+ }
+ }
// The result of the bottom bits of an integer multiply can be
// inferred by looking at the bottom bits of both operands and
@@ -873,8 +947,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
KnownBits Res(BitWidth);
Res.Zero.setHighBits(LeadZ);
+ Res.One.setHighBits(LeadO);
Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
- Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+ Res.One |= BottomKnown.getLoBits(ResultBitsKnown);
// If we're self-multiplying then bit[1] is guaranteed to be zero.
if (NoUndefSelfMultiply && BitWidth > 1) {
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
new file mode 100644
index 00000000000000..37526c67f0d9e1
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; %x is 0 or 2 and %y is 0 or 4, so %mul is 0 or 8. Bits 1 and 2 of the
+; product are therefore always clear, and the mask with 6 folds to 0.
+define i8 @mul_low_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+; %x is -2 or -1 (the `or` with -2 forces bits 7..1 to one) and %y is 0 or 4,
+; so %mul is 0, -4 or -8. Bit 1 is clear in each of those values, so the mask
+; with 2 folds to 0.
+define i8 @mul_low_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 2
+ ret i8 %r
+}
+
+; %x.notsmin is always -1 (or-ing with -4 and then with 3 sets every bit), so
+; %mul is the negation of %y. %y is -2 or -1, giving a product of 2 or 1:
+; bit 2 is always clear while bit 1 is not known. The CHECK lines accordingly
+; keep an `and`, but with the mask narrowed from 6 to 2.
+define i8 @mul_low_bits_partially_known(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_partially_known(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 2
+; CHECK-NEXT: [[MUL:%.*]] = sub nsw i8 0, [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 2
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %x.notsmin = or i8 %x, 3
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x.notsmin, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+; %x is in [-4, -1] and %y is -2 or -1, so %mul can be any of 1, 2, 3, 4, 6
+; or 8. Bits 1 and 2 both vary across those values, so the mask with 6 cannot
+; be folded to a constant.
+define i8 @mul_low_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 4
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 6
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 6
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+; %x is 0 or 2 and %y is 0 or 4, so %mul is 0 or 8. Bit 4 is always clear,
+; and the mask with 16 folds to 0.
+define i8 @mul_high_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 16
+ ret i8 %r
+}
+
+; %x is -2 or -1 and %y.nonzero is 1 or 5, so %mul is one of -1, -2, -5 or
+; -10. All of these lie in [-16, -1], so the top four bits are always set and
+; the mask with -16 folds to -16.
+define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 -16
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %y.nonzero = or i8 %y, 1
+ %mul = mul i8 %x, %y.nonzero
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+; Both operands are negative: %x is in [-4, -1] and %y is -2 or -1. Their
+; product is therefore in [1, 8], the top four bits are always clear, and the
+; mask with -16 folds to 0.
+define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+; %x is 0 or 2 and %y is 0 or 4, so %mul is 0 or 8. Bit 3 is the only bit
+; that may be set and its value is not known, so the `and` with 8 simplifies
+; to %mul itself rather than to a constant.
+define i8 @mul_high_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = and i8 [[XX]], 2
+; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]]
+; CHECK-NEXT: ret i8 [[MUL]]
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 8
+ ret i8 %r
+}
+
+; %x is -2 or -1 and %y is 0 or 4, so %mul is 0, -4 or -8. Because the product
+; may be zero, the top four bits are either all clear or all set, and the mask
+; with -16 cannot be folded to a constant.
+define i8 @mul_high_bits_unknown2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], -2
+; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT: [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], -16
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+; TODO: This can be reduced to zero.
+; %x is in [-4, -1] and %y is -2 or -1, so %mul is in [1, 8] and bit 4 is
+; always clear. The CHECK lines record that the `and` with 16 is nevertheless
+; left in place at present.
+define i8 @mul_high_bits_unknown3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 28
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 30
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 16
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 16
+ ret i8 %r
+}
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b16368de176481..2be2e1d093315c 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) {
}
}
-TEST(KnownBitsTest, MulExhaustive) {
+TEST(KnownBitsTest, MulLowBitsExhaustive) {
for (unsigned Bits : {1, 4}) {
ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
@@ -849,4 +849,151 @@ TEST(KnownBitsTest, MulExhaustive) {
}
}
+TEST(KnownBitsTest, MulHighBitsNoOverflow) {
+  // Checks that, whenever no concrete unsigned product overflows, the leading
+  // known zeros and leading known ones reported by KnownBits::mul are optimal.
+  // ForeachKnownBits, ForeachNumInKnownBits and checkResult are helpers
+  // defined outside this block.
+  for (unsigned Bits : {1, 4}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+
+        // Start with every bit marked known both ways, then intersect with
+        // each concrete product to obtain the exact known bits.
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        // Accumulate overflow over all products instead of relying on the
+        // last enumerated pair being the maximum product. For unsigned
+        // operands the two are equivalent: every product is bounded by the
+        // maximum one, so some product overflows iff the maximum one does.
+        // Initializing the flag also keeps it well-defined when no value
+        // pair is enumerated at all.
+        bool HasOverflow = false;
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            bool ProductOverflows = false;
+            APInt Res = N1.umul_ov(N2, ProductOverflows);
+            HasOverflow |= ProductOverflows;
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (Exact.hasConflict() || HasOverflow)
+          return;
+
+        // Restricts a KnownBits value to the bits selected by Mask.
+        auto MaskedTo = [&](const KnownBits &Known, const APInt &Mask) {
+          KnownBits Masked(Bits);
+          Masked.Zero = Known.Zero & Mask;
+          Masked.One = Known.One & Mask;
+          return Masked;
+        };
+
+        // Only the run of leading known zeros (respectively ones) of the
+        // exact result is compared for optimality.
+        APInt ZerosMask =
+            APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes());
+        APInt OnesMask =
+            APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+        KnownBits ExactZeros = MaskedTo(Exact, ZerosMask);
+        KnownBits ComputedZeros = MaskedTo(Computed, ZerosMask);
+        EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros,
+                                {Known1, Known2},
+                                /*CheckOptimality=*/true));
+
+        KnownBits ExactOnes = MaskedTo(Exact, OnesMask);
+        KnownBits ComputedOnes = MaskedTo(Computed, OnesMask);
+        EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes,
+                                {Known1, Known2},
+                                /*CheckOptimality=*/true));
+      });
+    });
+  }
+}
+
+TEST(KnownBitsTest, MulHighBitsOverflow) {
+  // Spot-checks hand-picked 4-bit inputs whose products overflow: for these
+  // examples the leading known zeros or ones from KnownBits::mul are expected
+  // to be optimal. It is certainly sub-optimal on other examples.
+  unsigned Bits = 4;
+
+  // Each operand is {constant value, index of its single unknown bit}; an
+  // index of -1 means the operand is fully known.
+  using KnownUnknownPair = std::pair<int, int>;
+  SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = {
+      {{2, 0}, {7, -1}},  // 001?, 0111
+      {{2, -1}, {10, 0}}, // 0010, 101?
+      {{9, 2}, {9, 1}},   // 1?01, 10?1
+      {{5, 1}, {3, 2}}};  // 01?1, 0?11
+
+  // Builds the KnownBits for one operand description: a constant with the
+  // designated bit (if any) turned back into an unknown.
+  auto MakeKnown = [&](const KnownUnknownPair &Desc) {
+    KnownBits Known = KnownBits::makeConstant(APInt(Bits, Desc.first));
+    if (Desc.second > -1) {
+      Known.Zero.clearBit(Desc.second);
+      Known.One.clearBit(Desc.second);
+    }
+    return Known;
+  };
+
+  // Restricts a KnownBits value to the bits selected by Mask.
+  auto MaskedTo = [&](const KnownBits &Known, const APInt &Mask) {
+    KnownBits Masked(Bits);
+    Masked.Zero = Known.Zero & Mask;
+    Masked.One = Known.One & Mask;
+    return Masked;
+  };
+
+  for (const auto &Case : TestPairs) {
+    KnownBits Known1 = MakeKnown(Case.first);
+    KnownBits Known2 = MakeKnown(Case.second);
+    KnownBits Computed = KnownBits::mul(Known1, Known2);
+
+    // Intersect every concrete (wrapping) product to get the exact result.
+    KnownBits Exact(Bits);
+    Exact.Zero.setAllBits();
+    Exact.One.setAllBits();
+    ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+      ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+        APInt Product = N1 * N2;
+        Exact.Zero &= ~Product;
+        Exact.One &= Product;
+      });
+    });
+
+    // Compare only the leading run of known zeros, then of known ones.
+    APInt ZerosMask =
+        APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes());
+    APInt OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+    KnownBits ExactZeros = MaskedTo(Exact, ZerosMask);
+    KnownBits ComputedZeros = MaskedTo(Computed, ZerosMask);
+    EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+
+    KnownBits ExactOnes = MaskedTo(Exact, OnesMask);
+    KnownBits ComputedOnes = MaskedTo(Computed, OnesMask);
+    EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+  }
+}
+
+TEST(KnownBitsTest, MulStress) {
+  // Soundness sweep for KnownBits::mul at widths 5 and 6: the computed result
+  // must be correct for every input pair, but need not be optimal.
+  for (unsigned Bits : {5, 6}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+
+        // Begin with all bits marked known both ways and narrow the claim
+        // with each concrete (wrapping) product.
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            APInt Product = N1 * N2;
+            Exact.Zero &= ~Product;
+            Exact.One &= Product;
+          });
+        });
+
+        // A conflicting Exact has nothing meaningful to compare against.
+        if (Exact.hasConflict())
+          return;
+        EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+                                /*CheckOptimality=*/false));
+      });
+    });
+  }
+}
} // end anonymous namespace
|
@llvm/pr-subscribers-llvm-analysis Author: Ramkumar Ramachandra (artagnon) Changes: Make the non-overflow case of KnownBits::mul optimal, and smoothly generalize it to the case when overflow occurs by relying on min-product in addition to max-product, noting that it cannot possibly be optimal unless we also look at the bits in between min-product and max-product. -- 8< -- Full diff: https://github.com/llvm/llvm-project/pull/114211.diff 3 Files Affected:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..c2d7c776725088 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -796,19 +796,93 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
assert((!NoUndefSelfMultiply || LHS == RHS) &&
"Self multiplication knownbits mismatch");
- // Compute the high known-0 bits by multiplying the unsigned max of each side.
- // Conservatively, M active bits * N active bits results in M + N bits in the
- // result. But if we know a value is a power-of-2 for example, then this
- // computes one more leading zero.
- // TODO: This could be generalized to number of sign bits (negative numbers).
- APInt UMaxLHS = LHS.getMaxValue();
- APInt UMaxRHS = RHS.getMaxValue();
-
- // For leading zeros in the result to be valid, the unsigned max product must
- // fit in the bitwidth (it must not overflow).
+ // Compute the high known-0 or known-1 bits by multiplying the min and max of
+ // each side.
+ APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
+ MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(),
+ MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
+ MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
+
+ // If MaxProduct doesn't overflow, it implies that MinProduct also won't
+ // overflow. However, if MaxProduct overflows, there is no guarantee on the
+ // MinProduct overflowing.
bool HasOverflow;
- APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
- unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
+ APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
+ MinProduct = MinLHS * MinRHS;
+
+ bool OpsSignMatch = LHS.isNegative() == RHS.isNegative();
+ if (!OpsSignMatch) {
+ // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
+ // negated to turn them into the corresponding signed-multiplication
+ // wrapped values.
+ MinProduct.negate();
+ MaxProduct.negate();
+ }
+
+ // Unless both MinProduct and MaxProduct are the same sign, there won't be any
+ // leading zeros or ones in the result.
+ unsigned LeadZ = 0, LeadO = 0;
+ if (MinProduct.isNegative() == MaxProduct.isNegative()) {
+ APInt LHSUnknown = (~LHS.Zero & ~LHS.One),
+ RHSUnknown = (~RHS.Zero & ~RHS.One);
+
+ // A product of M active bits * N active bits results in M + N bits in the
+ // result. If either of the operands is a power of two, the result has one
+ // less active bit.
+ auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
+ if (A.isZero() || B.isZero())
+ return 0;
+ return A.getActiveBits() + B.getActiveBits() -
+ (A.isPowerOf2() || B.isPowerOf2());
+ };
+
+ // We want to compute the number of active bits in the difference between
+ // the non-wrapped max product and non-wrapped min product, but we want to
+ // avoid computing the non-wrapped max/min product.
+ unsigned ActiveBitsInDiff;
+ if (MinLHS.isZero() && MinRHS.isZero())
+ ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
+ else
+ ActiveBitsInDiff =
+ ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
+ ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
+
+ // We uniformly handle the case where there is no max-overflow, in which
+ // case the high zeros and ones are computed optimally, and where there is,
+ // but the result shifts at most by BitWidth, in which case the high zeros
+ // and ones are not computed optimally.
+ if (!HasOverflow || ActiveBitsInDiff <= BitWidth) {
+ // Checks that A.ugt(B), excluding the degenerate case where A is all-ones
+ // and B is zero.
+ auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
+ return (!A.isAllOnes() || !B.isZero()) && A.ugt(B);
+ };
+
+ // If we're shifting by BitWidth, MaxProduct and MinProduct are swapped.
+ bool MinMaxSwap = ActiveBitsInDiff == BitWidth;
+ if (MinMaxSwap)
+ std::swap(MaxProduct, MinProduct);
+
+ if (OpsSignMatch != MinMaxSwap) {
+ // Normally, this is the case for when the signs of LHS and RHS match,
+ // and the else branch is for when the signs mismatch. However, if min
+ // and max were swapped, we need to invert these cases.
+ if (UgtCheckCorner(MaxProduct, MinProduct)) {
+ // Normally, when the signs of LHS and RHS match, we can safely set
+ // leading zeros of the result. However, if both MaxProduct and
+ // MinProduct are negative, we can also set the leading ones.
+ LeadZ = MaxProduct.countLeadingZeros();
+ LeadO = (MaxProduct & MinProduct).countLeadingOnes();
+ }
+ } else if (UgtCheckCorner(MinProduct, MaxProduct)) {
+ // Normally, when the signs of LHS and RHS mismatch, we can safely set
+ // leading ones of the result. However, if both MaxProduct and
+ // MinProduct are non-negative, we can also set the leading zeros.
+ LeadO = MaxProduct.countLeadingOnes();
+ LeadZ = (MaxProduct | MinProduct).countLeadingZeros();
+ }
+ }
+ }
// The result of the bottom bits of an integer multiply can be
// inferred by looking at the bottom bits of both operands and
@@ -873,8 +947,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
KnownBits Res(BitWidth);
Res.Zero.setHighBits(LeadZ);
+ Res.One.setHighBits(LeadO);
Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
- Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+ Res.One |= BottomKnown.getLoBits(ResultBitsKnown);
// If we're self-multiplying then bit[1] is guaranteed to be zero.
if (NoUndefSelfMultiply && BitWidth > 1) {
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
new file mode 100644
index 00000000000000..37526c67f0d9e1
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; %x is 0 or 2 and %y is 0 or 4, so %mul is 0 or 8. Bits 1 and 2 of the
+; product are therefore always clear, and the mask with 6 folds to 0.
+define i8 @mul_low_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+; %x is -2 or -1 (the `or` with -2 forces bits 7..1 to one) and %y is 0 or 4,
+; so %mul is 0, -4 or -8. Bit 1 is clear in each of those values, so the mask
+; with 2 folds to 0.
+define i8 @mul_low_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 2
+ ret i8 %r
+}
+
+; %x.notsmin is always -1 (or-ing with -4 and then with 3 sets every bit), so
+; %mul is the negation of %y. %y is -2 or -1, giving a product of 2 or 1:
+; bit 2 is always clear while bit 1 is not known. The CHECK lines accordingly
+; keep an `and`, but with the mask narrowed from 6 to 2.
+define i8 @mul_low_bits_partially_known(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_partially_known(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 2
+; CHECK-NEXT: [[MUL:%.*]] = sub nsw i8 0, [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 2
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %x.notsmin = or i8 %x, 3
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x.notsmin, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+; %x is in [-4, -1] and %y is -2 or -1, so %mul can be any of 1, 2, 3, 4, 6
+; or 8. Bits 1 and 2 both vary across those values, so the mask with 6 cannot
+; be folded to a constant.
+define i8 @mul_low_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_low_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 4
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 6
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 6
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 6
+ ret i8 %r
+}
+
+; %x is 0 or 2 and %y is 0 or 4, so %mul is 0 or 8. Bit 4 is always clear,
+; and the mask with 16 folds to 0.
+define i8 @mul_high_bits_know(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 16
+ ret i8 %r
+}
+
+; %x is -2 or -1 and %y.nonzero is 1 or 5, so %mul is one of -1, -2, -5 or
+; -10. All of these lie in [-16, -1], so the top four bits are always set and
+; the mask with -16 folds to -16.
+define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 -16
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %y.nonzero = or i8 %y, 1
+ %mul = mul i8 %x, %y.nonzero
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+; Both operands are negative: %x is in [-4, -1] and %y is -2 or -1. Their
+; product is therefore in [1, 8], the top four bits are always clear, and the
+; mask with -16 folds to 0.
+define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_know3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: ret i8 0
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+; %x is 0 or 2 and %y is 0 or 4, so %mul is 0 or 8. Bit 3 is the only bit
+; that may be set and its value is not known, so the `and` with 8 simplifies
+; to %mul itself rather than to a constant.
+define i8 @mul_high_bits_unknown(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = and i8 [[XX]], 2
+; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]]
+; CHECK-NEXT: ret i8 [[MUL]]
+;
+ %x = and i8 %xx, 2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 8
+ ret i8 %r
+}
+
+; %x is -2 or -1 and %y is 0 or 4, so %mul is 0, -4 or -8. Because the product
+; may be zero, the top four bits are either all clear or all set, and the mask
+; with -16 cannot be folded to a constant.
+define i8 @mul_high_bits_unknown2(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown2(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], -2
+; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4
+; CHECK-NEXT: [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], -16
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -2
+ %y = and i8 %yy, 4
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, -16
+ ret i8 %r
+}
+
+; TODO: This can be reduced to zero.
+; %x is in [-4, -1] and %y is -2 or -1, so %mul is in [1, 8] and bit 4 is
+; always clear. The CHECK lines record that the `and` with 16 is nevertheless
+; left in place at present.
+define i8 @mul_high_bits_unknown3(i8 %xx, i8 %yy) {
+; CHECK-LABEL: define i8 @mul_high_bits_unknown3(
+; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) {
+; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 28
+; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 30
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 16
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or i8 %xx, -4
+ %y = or i8 %yy, -2
+ %mul = mul i8 %x, %y
+ %r = and i8 %mul, 16
+ ret i8 %r
+}
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b16368de176481..2be2e1d093315c 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) {
}
}
-TEST(KnownBitsTest, MulExhaustive) {
+TEST(KnownBitsTest, MulLowBitsExhaustive) {
for (unsigned Bits : {1, 4}) {
ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
@@ -849,4 +849,151 @@ TEST(KnownBitsTest, MulExhaustive) {
}
}
+TEST(KnownBitsTest, MulHighBitsNoOverflow) {
+  // Checks that, whenever no concrete unsigned product overflows, the leading
+  // known zeros and leading known ones reported by KnownBits::mul are optimal.
+  // ForeachKnownBits, ForeachNumInKnownBits and checkResult are helpers
+  // defined outside this block.
+  for (unsigned Bits : {1, 4}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+
+        // Start with every bit marked known both ways, then intersect with
+        // each concrete product to obtain the exact known bits.
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+
+        // Accumulate overflow over all products instead of relying on the
+        // last enumerated pair being the maximum product. For unsigned
+        // operands the two are equivalent: every product is bounded by the
+        // maximum one, so some product overflows iff the maximum one does.
+        // Initializing the flag also keeps it well-defined when no value
+        // pair is enumerated at all.
+        bool HasOverflow = false;
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            bool ProductOverflows = false;
+            APInt Res = N1.umul_ov(N2, ProductOverflows);
+            HasOverflow |= ProductOverflows;
+            Exact.One &= Res;
+            Exact.Zero &= ~Res;
+          });
+        });
+
+        if (Exact.hasConflict() || HasOverflow)
+          return;
+
+        // Restricts a KnownBits value to the bits selected by Mask.
+        auto MaskedTo = [&](const KnownBits &Known, const APInt &Mask) {
+          KnownBits Masked(Bits);
+          Masked.Zero = Known.Zero & Mask;
+          Masked.One = Known.One & Mask;
+          return Masked;
+        };
+
+        // Only the run of leading known zeros (respectively ones) of the
+        // exact result is compared for optimality.
+        APInt ZerosMask =
+            APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes());
+        APInt OnesMask =
+            APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+        KnownBits ExactZeros = MaskedTo(Exact, ZerosMask);
+        KnownBits ComputedZeros = MaskedTo(Computed, ZerosMask);
+        EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros,
+                                {Known1, Known2},
+                                /*CheckOptimality=*/true));
+
+        KnownBits ExactOnes = MaskedTo(Exact, OnesMask);
+        KnownBits ComputedOnes = MaskedTo(Computed, OnesMask);
+        EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes,
+                                {Known1, Known2},
+                                /*CheckOptimality=*/true));
+      });
+    });
+  }
+}
+
+TEST(KnownBitsTest, MulHighBitsOverflow) {
+  // Spot-checks hand-picked 4-bit inputs whose products overflow: for these
+  // examples the leading known zeros or ones from KnownBits::mul are expected
+  // to be optimal. It is certainly sub-optimal on other examples.
+  unsigned Bits = 4;
+
+  // Each operand is {constant value, index of its single unknown bit}; an
+  // index of -1 means the operand is fully known.
+  using KnownUnknownPair = std::pair<int, int>;
+  SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = {
+      {{2, 0}, {7, -1}},  // 001?, 0111
+      {{2, -1}, {10, 0}}, // 0010, 101?
+      {{9, 2}, {9, 1}},   // 1?01, 10?1
+      {{5, 1}, {3, 2}}};  // 01?1, 0?11
+
+  // Builds the KnownBits for one operand description: a constant with the
+  // designated bit (if any) turned back into an unknown.
+  auto MakeKnown = [&](const KnownUnknownPair &Desc) {
+    KnownBits Known = KnownBits::makeConstant(APInt(Bits, Desc.first));
+    if (Desc.second > -1) {
+      Known.Zero.clearBit(Desc.second);
+      Known.One.clearBit(Desc.second);
+    }
+    return Known;
+  };
+
+  // Restricts a KnownBits value to the bits selected by Mask.
+  auto MaskedTo = [&](const KnownBits &Known, const APInt &Mask) {
+    KnownBits Masked(Bits);
+    Masked.Zero = Known.Zero & Mask;
+    Masked.One = Known.One & Mask;
+    return Masked;
+  };
+
+  for (const auto &Case : TestPairs) {
+    KnownBits Known1 = MakeKnown(Case.first);
+    KnownBits Known2 = MakeKnown(Case.second);
+    KnownBits Computed = KnownBits::mul(Known1, Known2);
+
+    // Intersect every concrete (wrapping) product to get the exact result.
+    KnownBits Exact(Bits);
+    Exact.Zero.setAllBits();
+    Exact.One.setAllBits();
+    ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+      ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+        APInt Product = N1 * N2;
+        Exact.Zero &= ~Product;
+        Exact.One &= Product;
+      });
+    });
+
+    // Compare only the leading run of known zeros, then of known ones.
+    APInt ZerosMask =
+        APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes());
+    APInt OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
+
+    KnownBits ExactZeros = MaskedTo(Exact, ZerosMask);
+    KnownBits ComputedZeros = MaskedTo(Computed, ZerosMask);
+    EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+
+    KnownBits ExactOnes = MaskedTo(Exact, OnesMask);
+    KnownBits ComputedOnes = MaskedTo(Computed, OnesMask);
+    EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2},
+                            /*CheckOptimality=*/true));
+  }
+}
+
+TEST(KnownBitsTest, MulStress) {
+  // Soundness sweep for KnownBits::mul at widths 5 and 6: the computed result
+  // must be correct for every input pair, but need not be optimal.
+  for (unsigned Bits : {5, 6}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+        KnownBits Computed = KnownBits::mul(Known1, Known2);
+
+        // Begin with all bits marked known both ways and narrow the claim
+        // with each concrete (wrapping) product.
+        KnownBits Exact(Bits);
+        Exact.Zero.setAllBits();
+        Exact.One.setAllBits();
+        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+            APInt Product = N1 * N2;
+            Exact.Zero &= ~Product;
+            Exact.One &= Product;
+          });
+        });
+
+        // A conflicting Exact has nothing meaningful to compare against.
+        if (Exact.hasConflict())
+          return;
+        EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
+                                /*CheckOptimality=*/false));
+      });
+    });
+  }
+}
} // end anonymous namespace
|
4090b2c
to
eef289a
Compare
Make the non-overflow case of KnownBits::mul optimal, and smoothly generalize it to the case when overflow occurs by relying on min-product in addition to max-product, noting that it cannot possibly be optimal unless we also look at the bits in between min-product and max-product.
eef289a
to
04f1601
Compare
b127b93
to
dd73f6c
Compare
dd73f6c
to
ae92a27
Compare
Make the non-overflow case of KnownBits::mul optimal, and smoothly generalize it to the case when overflow occurs by relying on min-product in addition to max-product, noting that it cannot possibly be optimal unless we also look at the bits in between min-product and max-product.
-- 8< --
Based on #113051.