Skip to content

Commit 1b76120

Browse files
authored
[APInt] Add a simpler overload of multiplicativeInverse (llvm#87610)
The current APInt::multiplicativeInverse takes a modulus which can be any value, but all in-tree callers use a power of two. Moreover, most callers want to use two to the power of the width of an existing APInt, which is awkward because 2^N is not representable as an N-bit APInt. Add a new overload of multiplicativeInverse which implicitly uses 2^BitWidth as the modulus.
1 parent 51f1cb5 commit 1b76120

File tree

6 files changed

+27
-31
lines changed

6 files changed

+27
-31
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,6 +1743,9 @@ class [[nodiscard]] APInt {
17431743
/// \returns the multiplicative inverse for a given modulo.
17441744
APInt multiplicativeInverse(const APInt &modulo) const;
17451745

1746+
/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
1747+
APInt multiplicativeInverse() const;
1748+
17461749
/// @}
17471750
/// \name Building-block Operations for APInt and APFloat
17481751
/// @{

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -944,10 +944,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
944944
// Calculate the multiplicative inverse of K! / 2^T;
945945
// this multiplication factor will perform the exact division by
946946
// K! / 2^T.
947-
APInt Mod = APInt::getSignedMinValue(W+1);
948-
APInt MultiplyFactor = OddFactorial.zext(W+1);
949-
MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
950-
MultiplyFactor = MultiplyFactor.trunc(W);
947+
APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
951948

952949
// Calculate the product, at width T+W
953950
IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
@@ -10086,10 +10083,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1008610083
// If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
1008710084
// (N / D) in general. The inverse itself always fits into BW bits, though,
1008810085
// so we immediately truncate it.
10089-
APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
10090-
APInt Mod(BW + 1, 0);
10091-
Mod.setBit(BW - Mult2); // Mod = N / D
10092-
APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
10086+
APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10087+
APInt I = AD.multiplicativeInverse().zext(BW);
1009310088

1009410089
// 4. Compute the minimum unsigned root of the equation:
1009510090
// I * (B / D) mod (N / D)

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5201,10 +5201,7 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
52015201

52025202
// Calculate the multiplicative inverse modulo BW.
52035203
// 2^W requires W + 1 bits, so we have to extend and then truncate.
5204-
unsigned W = Divisor.getBitWidth();
5205-
APInt Factor = Divisor.zext(W + 1)
5206-
.multiplicativeInverse(APInt::getSignedMinValue(W + 1))
5207-
.trunc(W);
5204+
APInt Factor = Divisor.multiplicativeInverse();
52085205
Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
52095206
Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
52105207
return true;

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6071,11 +6071,7 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
60716071
Divisor.ashrInPlace(Shift);
60726072
UseSRA = true;
60736073
}
6074-
// Calculate the multiplicative inverse, using Newton's method.
6075-
APInt t;
6076-
APInt Factor = Divisor;
6077-
while ((t = Divisor * Factor) != 1)
6078-
Factor *= APInt(Divisor.getBitWidth(), 2) - t;
6074+
APInt Factor = Divisor.multiplicativeInverse();
60796075
Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
60806076
Factors.push_back(DAG.getConstant(Factor, dl, SVT));
60816077
return true;
@@ -6664,10 +6660,7 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
66646660
// P = inv(D0, 2^W)
66656661
// 2^W requires W + 1 bits, so we have to extend and then truncate.
66666662
unsigned W = D.getBitWidth();
6667-
APInt P = D0.zext(W + 1)
6668-
.multiplicativeInverse(APInt::getSignedMinValue(W + 1))
6669-
.trunc(W);
6670-
assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
6663+
APInt P = D0.multiplicativeInverse();
66716664
assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
66726665

66736666
// Q = floor((2^W - 1) u/ D)
@@ -6922,10 +6915,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
69226915
// P = inv(D0, 2^W)
69236916
// 2^W requires W + 1 bits, so we have to extend and then truncate.
69246917
unsigned W = D.getBitWidth();
6925-
APInt P = D0.zext(W + 1)
6926-
.multiplicativeInverse(APInt::getSignedMinValue(W + 1))
6927-
.trunc(W);
6928-
assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
6918+
APInt P = D0.multiplicativeInverse();
69296919
assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
69306920

69316921
// A = floor((2^(W - 1) - 1) / D0) & -2^K
@@ -7651,7 +7641,7 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
76517641
//
76527642
// For division, we can compute the remainder using the algorithm described
76537643
// above, subtract it from the dividend to get an exact multiple of Constant.
7654-
// Then multiply that extact multiply by the multiplicative inverse modulo
7644+
// Then multiply that exact multiply by the multiplicative inverse modulo
76557645
// (1 << (BitWidth / 2)) to get the quotient.
76567646

76577647
// If Constant is even, we can shift right the dividend and the divisor by the
@@ -7786,10 +7776,7 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
77867776

77877777
// Multiply by the multiplicative inverse of the divisor modulo
77887778
// (1 << BitWidth).
7789-
APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
7790-
APInt MulFactor = Divisor.zext(BitWidth + 1);
7791-
MulFactor = MulFactor.multiplicativeInverse(Mod);
7792-
MulFactor = MulFactor.trunc(BitWidth);
7779+
APInt MulFactor = Divisor.multiplicativeInverse();
77937780

77947781
SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend,
77957782
DAG.getConstant(MulFactor, dl, VT));

llvm/lib/Support/APInt.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,19 @@ APInt APInt::multiplicativeInverse(const APInt& modulo) const {
12891289
return std::move(t[i]);
12901290
}
12911291

1292+
/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
1293+
APInt APInt::multiplicativeInverse() const {
1294+
assert((*this)[0] &&
1295+
"multiplicative inverse is only defined for odd numbers!");
1296+
1297+
// Use Newton's method.
1298+
APInt Factor = *this;
1299+
APInt T;
1300+
while (!(T = *this * Factor).isOne())
1301+
Factor *= 2 - T;
1302+
return Factor;
1303+
}
1304+
12921305
/// Implementation of Knuth's Algorithm D (Division of nonnegative integers)
12931306
/// from "Art of Computer Programming, Volume 2", section 4.3.1, p. 272. The
12941307
/// variables here have the same names as in the algorithm. Comments explain

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3257,9 +3257,10 @@ TEST(APIntTest, MultiplicativeInverseExaustive) {
32573257
.multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
32583258
.trunc(BitWidth);
32593259
APInt One = V * MulInv;
3260-
if (!V.isZero() && V.countr_zero() == 0) {
3260+
if (V[0]) {
32613261
// Multiplicative inverse exists for all odd numbers.
32623262
EXPECT_TRUE(One.isOne());
3263+
EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
32633264
} else {
32643265
// Multiplicative inverse does not exist for even numbers (and 0).
32653266
EXPECT_TRUE(MulInv.isZero());

0 commit comments

Comments
 (0)