Skip to content

[APInt] Add a simpler overload of multiplicativeInverse #87610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2024

Conversation

jayfoad
Copy link
Contributor

@jayfoad jayfoad commented Apr 4, 2024

The current APInt::multiplicativeInverse takes a modulus which can be
any value, but all in-tree callers use a power of two. Moreover, most
callers want to use two to the power of the width of an existing APInt,
which is awkward because 2^N is not representable as an N-bit APInt.

Add a new overload of multiplicativeInverse which implicitly uses
2^BitWidth as the modulus.

jayfoad added 2 commits April 4, 2024 09:30
The current APInt::multiplicativeInverse takes a modulus which can be
any value, but all in-tree callers use a power of two. Moreover, most
callers want to use two to the power of the width of an existing APInt,
which is awkward because 2^N is not representable as an N-bit APInt.

Add a new overload of multiplicativeInverse which implicitly uses
2^BitWidth as the modulus.
@jayfoad jayfoad requested a review from nikic as a code owner April 4, 2024 08:46
@llvmbot llvmbot added llvm:globalisel llvm:support llvm:SelectionDAG SelectionDAGISel as well llvm:analysis Includes value tracking, cost tables and constant folding llvm:adt labels Apr 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 4, 2024

@llvm/pr-subscribers-llvm-adt
@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-llvm-support
@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-llvm-analysis

Author: Jay Foad (jayfoad)

Changes

The current APInt::multiplicativeInverse takes a modulus which can be
any value, but all in-tree callers use a power of two. Moreover, most
callers want to use two to the power of the width of an existing APInt,
which is awkward because 2^N is not representable as an N-bit APInt.

Add a new overload of multiplicativeInverse which implicitly uses
2^BitWidth as the modulus.


Full diff: https://github.com/llvm/llvm-project/pull/87610.diff

6 Files Affected:

  • (modified) llvm/include/llvm/ADT/APInt.h (+3)
  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+3-8)
  • (modified) llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp (+1-3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+5-18)
  • (modified) llvm/lib/Support/APInt.cpp (+13)
  • (modified) llvm/unittests/ADT/APIntTest.cpp (+2-1)
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index b9b39f3b9dfbc4..bd1716219ee5fc 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1743,6 +1743,9 @@ class [[nodiscard]] APInt {
   /// \returns the multiplicative inverse for a given modulo.
   APInt multiplicativeInverse(const APInt &modulo) const;
 
+  /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
+  APInt multiplicativeInverse() const;
+
   /// @}
   /// \name Building-block Operations for APInt and APFloat
   /// @{
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 515b9d0744f6e3..e030b9fc7dac4f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -944,10 +944,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
   // Calculate the multiplicative inverse of K! / 2^T;
   // this multiplication factor will perform the exact division by
   // K! / 2^T.
-  APInt Mod = APInt::getSignedMinValue(W+1);
-  APInt MultiplyFactor = OddFactorial.zext(W+1);
-  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
-  MultiplyFactor = MultiplyFactor.trunc(W);
+  APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
 
   // Calculate the product, at width T+W
   IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
@@ -10086,10 +10083,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
   // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
   // (N / D) in general. The inverse itself always fits into BW bits, though,
   // so we immediately truncate it.
-  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
-  APInt Mod(BW + 1, 0);
-  Mod.setBit(BW - Mult2);  // Mod = N / D
-  APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
+  APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
+  APInt I = AD.multiplicativeInverse().zext(BW);
 
   // 4. Compute the minimum unsigned root of the equation:
   // I * (B / D) mod (N / D)
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 062132c8304b06..e7a46700962157 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -5202,9 +5202,7 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
     // Calculate the multiplicative inverse modulo BW.
     // 2^W requires W + 1 bits, so we have to extend and then truncate.
     unsigned W = Divisor.getBitWidth();
-    APInt Factor = Divisor.zext(W + 1)
-                       .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
-                       .trunc(W);
+    APInt Factor = Divisor.multiplicativeInverse();
     Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
     Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
     return true;
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 8bb9541bfe1027..8aff2a081024c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6050,11 +6050,7 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
       Divisor.ashrInPlace(Shift);
       UseSRA = true;
     }
-    // Calculate the multiplicative inverse, using Newton's method.
-    APInt t;
-    APInt Factor = Divisor;
-    while ((t = Divisor * Factor) != 1)
-      Factor *= APInt(Divisor.getBitWidth(), 2) - t;
+    APInt Factor = Divisor.multiplicativeInverse();
     Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
     Factors.push_back(DAG.getConstant(Factor, dl, SVT));
     return true;
@@ -6643,10 +6639,7 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
     // P = inv(D0, 2^W)
     // 2^W requires W + 1 bits, so we have to extend and then truncate.
     unsigned W = D.getBitWidth();
-    APInt P = D0.zext(W + 1)
-                  .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
-                  .trunc(W);
-    assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
+    APInt P = D0.multiplicativeInverse();
     assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
 
     // Q = floor((2^W - 1) u/ D)
@@ -6901,10 +6894,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
     // P = inv(D0, 2^W)
     // 2^W requires W + 1 bits, so we have to extend and then truncate.
     unsigned W = D.getBitWidth();
-    APInt P = D0.zext(W + 1)
-                  .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
-                  .trunc(W);
-    assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
+    APInt P = D0.multiplicativeInverse();
     assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
 
     // A = floor((2^(W - 1) - 1) / D0) & -2^K
@@ -7630,7 +7620,7 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
 //
 // For division, we can compute the remainder using the algorithm described
 // above, subtract it from the dividend to get an exact multiple of Constant.
-// Then multiply that extact multiply by the multiplicative inverse modulo
+// Then multiply that exact multiply by the multiplicative inverse modulo
 // (1 << (BitWidth / 2)) to get the quotient.
 
 // If Constant is even, we can shift right the dividend and the divisor by the
@@ -7765,10 +7755,7 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
 
     // Multiply by the multiplicative inverse of the divisor modulo
     // (1 << BitWidth).
-    APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
-    APInt MulFactor = Divisor.zext(BitWidth + 1);
-    MulFactor = MulFactor.multiplicativeInverse(Mod);
-    MulFactor = MulFactor.trunc(BitWidth);
+    APInt MulFactor = Divisor.multiplicativeInverse();
 
     SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend,
                                    DAG.getConstant(MulFactor, dl, VT));
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index c20609748dc97c..f8f699f8f6ccd7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1289,6 +1289,19 @@ APInt APInt::multiplicativeInverse(const APInt& modulo) const {
   return std::move(t[i]);
 }
 
+/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
+APInt APInt::multiplicativeInverse() const {
+  assert((*this)[0] &&
+         "multiplicative inverse is only defined for odd numbers!");
+
+  // Use Newton's method.
+  APInt Factor = *this;
+  APInt T;
+  while (!(T = *this * Factor).isOne())
+    Factor *= 2 - T;
+  return Factor;
+}
+
 /// Implementation of Knuth's Algorithm D (Division of nonnegative integers)
 /// from "Art of Computer Programming, Volume 2", section 4.3.1, p. 272. The
 /// variables here have the same names as in the algorithm. Comments explain
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index d5ef63e38e2790..23f9ee2d39c441 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3257,9 +3257,10 @@ TEST(APIntTest, MultiplicativeInverseExaustive) {
               .multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
               .trunc(BitWidth);
       APInt One = V * MulInv;
-      if (!V.isZero() && V.countr_zero() == 0) {
+      if (V[0]) {
         // Multiplicative inverse exists for all odd numbers.
         EXPECT_TRUE(One.isOne());
+        EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
       } else {
         // Multiplicative inverse does not exist for even numbers (and 0).
         EXPECT_TRUE(MulInv.isZero());

@jayfoad
Copy link
Contributor Author

jayfoad commented Apr 4, 2024

The old multiplicativeInverse can be removed after this, but I would do it as a separate commit.

@jayfoad jayfoad requested a review from topperc April 4, 2024 08:48
Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jayfoad jayfoad merged commit 1b76120 into llvm:main Apr 4, 2024
@jayfoad jayfoad deleted the mult-inv branch April 4, 2024 15:11
@jayfoad
Copy link
Contributor Author

jayfoad commented Apr 4, 2024

The old multiplicativeInverse can be removed after this, but I would do it as a separate commit.

#87644

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:adt llvm:analysis Includes value tracking, cost tables and constant folding llvm:globalisel llvm:SelectionDAG SelectionDAGISel as well llvm:support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants