-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[APInt] Add a simpler overload of multiplicativeInverse #87610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The current APInt::multiplicativeInverse takes a modulus which can be any value, but all in-tree callers use a power of two. Moreover, most callers want to use two to the power of the width of an existing APInt, which is awkward because 2^N is not representable as an N-bit APInt. Add a new overload of multiplicativeInverse which implicitly uses 2^BitWidth as the modulus.
@llvm/pr-subscribers-llvm-adt @llvm/pr-subscribers-llvm-analysis Author: Jay Foad (jayfoad) ChangesThe current APInt::multiplicativeInverse takes a modulus which can be Add a new overload of multiplicativeInverse which implicitly uses Full diff: https://github.com/llvm/llvm-project/pull/87610.diff 6 Files Affected:
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index b9b39f3b9dfbc4..bd1716219ee5fc 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1743,6 +1743,9 @@ class [[nodiscard]] APInt {
/// \returns the multiplicative inverse for a given modulo.
APInt multiplicativeInverse(const APInt &modulo) const;
+ /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
+ APInt multiplicativeInverse() const;
+
/// @}
/// \name Building-block Operations for APInt and APFloat
/// @{
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 515b9d0744f6e3..e030b9fc7dac4f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -944,10 +944,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
// Calculate the multiplicative inverse of K! / 2^T;
// this multiplication factor will perform the exact division by
// K! / 2^T.
- APInt Mod = APInt::getSignedMinValue(W+1);
- APInt MultiplyFactor = OddFactorial.zext(W+1);
- MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
- MultiplyFactor = MultiplyFactor.trunc(W);
+ APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
// Calculate the product, at width T+W
IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
@@ -10086,10 +10083,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
// If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
// (N / D) in general. The inverse itself always fits into BW bits, though,
// so we immediately truncate it.
- APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
- APInt Mod(BW + 1, 0);
- Mod.setBit(BW - Mult2); // Mod = N / D
- APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
+ APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
+ APInt I = AD.multiplicativeInverse().zext(BW);
// 4. Compute the minimum unsigned root of the equation:
// I * (B / D) mod (N / D)
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 062132c8304b06..e7a46700962157 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -5202,9 +5202,7 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
// Calculate the multiplicative inverse modulo BW.
// 2^W requires W + 1 bits, so we have to extend and then truncate.
unsigned W = Divisor.getBitWidth();
- APInt Factor = Divisor.zext(W + 1)
- .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
- .trunc(W);
+ APInt Factor = Divisor.multiplicativeInverse();
Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
return true;
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 8bb9541bfe1027..8aff2a081024c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6050,11 +6050,7 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
Divisor.ashrInPlace(Shift);
UseSRA = true;
}
- // Calculate the multiplicative inverse, using Newton's method.
- APInt t;
- APInt Factor = Divisor;
- while ((t = Divisor * Factor) != 1)
- Factor *= APInt(Divisor.getBitWidth(), 2) - t;
+ APInt Factor = Divisor.multiplicativeInverse();
Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
Factors.push_back(DAG.getConstant(Factor, dl, SVT));
return true;
@@ -6643,10 +6639,7 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
// P = inv(D0, 2^W)
// 2^W requires W + 1 bits, so we have to extend and then truncate.
unsigned W = D.getBitWidth();
- APInt P = D0.zext(W + 1)
- .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
- .trunc(W);
- assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
+ APInt P = D0.multiplicativeInverse();
assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
// Q = floor((2^W - 1) u/ D)
@@ -6901,10 +6894,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
// P = inv(D0, 2^W)
// 2^W requires W + 1 bits, so we have to extend and then truncate.
unsigned W = D.getBitWidth();
- APInt P = D0.zext(W + 1)
- .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
- .trunc(W);
- assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
+ APInt P = D0.multiplicativeInverse();
assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
// A = floor((2^(W - 1) - 1) / D0) & -2^K
@@ -7630,7 +7620,7 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
//
// For division, we can compute the remainder using the algorithm described
// above, subtract it from the dividend to get an exact multiple of Constant.
-// Then multiply that extact multiply by the multiplicative inverse modulo
+// Then multiply that exact multiply by the multiplicative inverse modulo
// (1 << (BitWidth / 2)) to get the quotient.
// If Constant is even, we can shift right the dividend and the divisor by the
@@ -7765,10 +7755,7 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
// Multiply by the multiplicative inverse of the divisor modulo
// (1 << BitWidth).
- APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
- APInt MulFactor = Divisor.zext(BitWidth + 1);
- MulFactor = MulFactor.multiplicativeInverse(Mod);
- MulFactor = MulFactor.trunc(BitWidth);
+ APInt MulFactor = Divisor.multiplicativeInverse();
SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend,
DAG.getConstant(MulFactor, dl, VT));
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index c20609748dc97c..f8f699f8f6ccd7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1289,6 +1289,19 @@ APInt APInt::multiplicativeInverse(const APInt& modulo) const {
return std::move(t[i]);
}
+/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
+APInt APInt::multiplicativeInverse() const {
+ assert((*this)[0] &&
+ "multiplicative inverse is only defined for odd numbers!");
+
+ // Use Newton's method.
+ APInt Factor = *this;
+ APInt T;
+ while (!(T = *this * Factor).isOne())
+ Factor *= 2 - T;
+ return Factor;
+}
+
/// Implementation of Knuth's Algorithm D (Division of nonnegative integers)
/// from "Art of Computer Programming, Volume 2", section 4.3.1, p. 272. The
/// variables here have the same names as in the algorithm. Comments explain
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index d5ef63e38e2790..23f9ee2d39c441 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3257,9 +3257,10 @@ TEST(APIntTest, MultiplicativeInverseExaustive) {
.multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
.trunc(BitWidth);
APInt One = V * MulInv;
- if (!V.isZero() && V.countr_zero() == 0) {
+ if (V[0]) {
// Multiplicative inverse exists for all odd numbers.
EXPECT_TRUE(One.isOne());
+ EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
} else {
// Multiplicative inverse does not exist for even numbers (and 0).
EXPECT_TRUE(MulInv.isZero());
|
The old multiplicativeInverse can be removed after this, but I would do it as a separate commit. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
The current APInt::multiplicativeInverse takes a modulus which can be
any value, but all in-tree callers use a power of two. Moreover, most
callers want to use two to the power of the width of an existing APInt,
which is awkward because 2^N is not representable as an N-bit APInt.
Add a new overload of multiplicativeInverse which implicitly uses
2^BitWidth as the modulus.