Skip to content

[APInt] Add a simpler overload of multiplicativeInverse #87610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/ADT/APInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1743,6 +1743,9 @@ class [[nodiscard]] APInt {
/// \returns the multiplicative inverse for a given modulo.
APInt multiplicativeInverse(const APInt &modulo) const;

/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
APInt multiplicativeInverse() const;

/// @}
/// \name Building-block Operations for APInt and APFloat
/// @{
Expand Down
11 changes: 3 additions & 8 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -944,10 +944,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
// Calculate the multiplicative inverse of K! / 2^T;
// this multiplication factor will perform the exact division by
// K! / 2^T.
APInt Mod = APInt::getSignedMinValue(W+1);
APInt MultiplyFactor = OddFactorial.zext(W+1);
MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
MultiplyFactor = MultiplyFactor.trunc(W);
APInt MultiplyFactor = OddFactorial.multiplicativeInverse();

// Calculate the product, at width T+W
IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
Expand Down Expand Up @@ -10086,10 +10083,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
// If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
// (N / D) in general. The inverse itself always fits into BW bits, though,
// so we immediately truncate it.
APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
APInt Mod(BW + 1, 0);
Mod.setBit(BW - Mult2); // Mod = N / D
APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
APInt I = AD.multiplicativeInverse().zext(BW);

// 4. Compute the minimum unsigned root of the equation:
// I * (B / D) mod (N / D)
Expand Down
5 changes: 1 addition & 4 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5201,10 +5201,7 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {

// Calculate the multiplicative inverse modulo BW.
// 2^W requires W + 1 bits, so we have to extend and then truncate.
unsigned W = Divisor.getBitWidth();
APInt Factor = Divisor.zext(W + 1)
.multiplicativeInverse(APInt::getSignedMinValue(W + 1))
.trunc(W);
APInt Factor = Divisor.multiplicativeInverse();
Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
return true;
Expand Down
23 changes: 5 additions & 18 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6050,11 +6050,7 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
Divisor.ashrInPlace(Shift);
UseSRA = true;
}
// Calculate the multiplicative inverse, using Newton's method.
APInt t;
APInt Factor = Divisor;
while ((t = Divisor * Factor) != 1)
Factor *= APInt(Divisor.getBitWidth(), 2) - t;
APInt Factor = Divisor.multiplicativeInverse();
Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
Factors.push_back(DAG.getConstant(Factor, dl, SVT));
return true;
Expand Down Expand Up @@ -6643,10 +6639,7 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
// P = inv(D0, 2^W)
// 2^W requires W + 1 bits, so we have to extend and then truncate.
unsigned W = D.getBitWidth();
APInt P = D0.zext(W + 1)
.multiplicativeInverse(APInt::getSignedMinValue(W + 1))
.trunc(W);
assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
APInt P = D0.multiplicativeInverse();
assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");

// Q = floor((2^W - 1) u/ D)
Expand Down Expand Up @@ -6901,10 +6894,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
// P = inv(D0, 2^W)
// 2^W requires W + 1 bits, so we have to extend and then truncate.
unsigned W = D.getBitWidth();
APInt P = D0.zext(W + 1)
.multiplicativeInverse(APInt::getSignedMinValue(W + 1))
.trunc(W);
assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
APInt P = D0.multiplicativeInverse();
assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");

// A = floor((2^(W - 1) - 1) / D0) & -2^K
Expand Down Expand Up @@ -7630,7 +7620,7 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
//
// For division, we can compute the remainder using the algorithm described
// above, subtract it from the dividend to get an exact multiple of Constant.
// Then multiply that extact multiply by the multiplicative inverse modulo
// Then multiply that exact multiply by the multiplicative inverse modulo
// (1 << (BitWidth / 2)) to get the quotient.

// If Constant is even, we can shift right the dividend and the divisor by the
Expand Down Expand Up @@ -7765,10 +7755,7 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,

// Multiply by the multiplicative inverse of the divisor modulo
// (1 << BitWidth).
APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
APInt MulFactor = Divisor.zext(BitWidth + 1);
MulFactor = MulFactor.multiplicativeInverse(Mod);
MulFactor = MulFactor.trunc(BitWidth);
APInt MulFactor = Divisor.multiplicativeInverse();

SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend,
DAG.getConstant(MulFactor, dl, VT));
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Support/APInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,19 @@ APInt APInt::multiplicativeInverse(const APInt& modulo) const {
return std::move(t[i]);
}

/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
APInt APInt::multiplicativeInverse() const {
assert((*this)[0] &&
"multiplicative inverse is only defined for odd numbers!");

// Use Newton's method.
APInt Factor = *this;
APInt T;
while (!(T = *this * Factor).isOne())
Factor *= 2 - T;
return Factor;
}

/// Implementation of Knuth's Algorithm D (Division of nonnegative integers)
/// from "Art of Computer Programming, Volume 2", section 4.3.1, p. 272. The
/// variables here have the same names as in the algorithm. Comments explain
Expand Down
3 changes: 2 additions & 1 deletion llvm/unittests/ADT/APIntTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3257,9 +3257,10 @@ TEST(APIntTest, MultiplicativeInverseExaustive) {
.multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
.trunc(BitWidth);
APInt One = V * MulInv;
if (!V.isZero() && V.countr_zero() == 0) {
if (V[0]) {
// Multiplicative inverse exists for all odd numbers.
EXPECT_TRUE(One.isOne());
EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
} else {
// Multiplicative inverse does not exist for even numbers (and 0).
EXPECT_TRUE(MulInv.isZero());
Expand Down