Skip to content

[SCEV] Look through multiply in computeConstantDifference() #103051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 14, 2024

Conversation

nikic
Copy link
Contributor

@nikic nikic commented Aug 13, 2024

Inside computeConstantDifference(), handle the case where both sides are of the form C * %x, in which case we can strip off the common multiplication (as long as we remember to multiply by it for the following difference calculation).

There is an obvious alternative implementation here, which would be to directly decompose multiplies inside the "Multiplicity" accumulation. This does work, but I've found this to be both significantly slower (because everything has to work on APInt) and more complex in implementation (e.g. because we now need to match back the new More/Less with an arbitrary factor) without providing more power in practice. As such, I went for the simpler variant here.

This is the last step to make computeConstantDifference() sufficiently powerful to replace existing uses of
cast<SCEVConstant>(getMinusSCEV()) with it.

Inside computeConstantDifference(), handle the case where both
sides are of the form `C * %x`, in which case we can strip off
the common multiplication (as long as we remember to multiply
by it for the following difference calculation).

There is an obvious alternative implementation here, which would
be to directly decompose multiplies inside the "Multiplicity"
accumulation. This does work, but I've found this to be both
significantly slower (because everything has to work on APInt)
and more complex in implementation (e.g. because we now need to
match back the new More/Less with an arbitrary factor) without
providing more power in practice. As such, I went for the simpler
variant here.

This is the last step to make computeConstantDifference()
sufficiently powerful to replace existing uses of
`cast<SCEVConstant>(getMinusSCEV())` with it.
@nikic nikic requested review from fhahn and preames August 13, 2024 12:50
@llvmbot llvmbot added the llvm:analysis Includes value tracking, cost tables and constant folding label Aug 13, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Nikita Popov (nikic)

Changes

Inside computeConstantDifference(), handle the case where both sides are of the form C * %x, in which case we can strip off the common multiplication (as long as we remember to multiply by it for the following difference calculation).

There is an obvious alternative implementation here, which would be to directly decompose multiplies inside the "Multiplicity" accumulation. This does work, but I've found this to be both significantly slower (because everything has to work on APInt) and more complex in implementation (e.g. because we now need to match back the new More/Less with an arbitrary factor) without providing more power in practice. As such, I went for the simpler variant here.

This is the last step to make computeConstantDifference() sufficiently powerful to replace existing uses of
cast&lt;SCEVConstant&gt;(getMinusSCEV()) with it.


Full diff: https://github.com/llvm/llvm-project/pull/103051.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+25-3)
  • (modified) llvm/unittests/Analysis/ScalarEvolutionTest.cpp (+9)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 78975bee4d72c4..d03bb0d2156433 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11953,9 +11953,10 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
 
   unsigned BW = getTypeSizeInBits(More->getType());
   APInt Diff(BW, 0);
+  APInt DiffMul(BW, 1);
   // Try various simplifications to reduce the difference to a constant. Limit
   // the number of allowed simplifications to keep compile-time low.
-  for (unsigned I = 0; I < 4; ++I) {
+  for (unsigned I = 0; I < 5; ++I) {
     if (More == Less)
       return Diff;
 
@@ -11980,15 +11981,36 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
       continue;
     }
 
+    // Try to match a common constant multiply.
+    auto MatchConstMul =
+        [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
+      auto *M = dyn_cast<SCEVMulExpr>(S);
+      if (!M || M->getNumOperands() != 2 ||
+          !isa<SCEVConstant>(M->getOperand(0)))
+        return std::nullopt;
+      return {
+          {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
+    };
+    if (auto MatchedMore = MatchConstMul(More)) {
+      if (auto MatchedLess = MatchConstMul(Less)) {
+        if (MatchedMore->second == MatchedLess->second) {
+          More = MatchedMore->first;
+          Less = MatchedLess->first;
+          DiffMul *= MatchedMore->second;
+          continue;
+        }
+      }
+    }
+
     // Try to cancel out common factors in two add expressions.
     SmallDenseMap<const SCEV *, int, 8> Multiplicity;
     auto Add = [&](const SCEV *S, int Mul) {
       if (auto *C = dyn_cast<SCEVConstant>(S)) {
         if (Mul == 1) {
-          Diff += C->getAPInt();
+          Diff += C->getAPInt() * DiffMul;
         } else {
           assert(Mul == -1);
-          Diff -= C->getAPInt();
+          Diff -= C->getAPInt() * DiffMul;
         }
       } else
         Multiplicity[S] += Mul;
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 42aad6ae507bf6..d4d90d80f4cea1 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1142,6 +1142,8 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
         %var = load i32, ptr %ptr
         %iv2pvar = add i32 %iv2, %var
         %iv2pvarp3 = add i32 %iv2pvar, 3
+        %iv2pvarm3 = mul i32 %iv2pvar, 3
+        %iv2pvarp3m3 = mul i32 %iv2pvarp3, 3
         %cmp2 = icmp sle i32 %iv2.next, %sz
         br i1 %cmp2, label %loop2.body, label %exit
       exit:
@@ -1178,6 +1180,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
     // %var + {{3,+,1},+,1}
     const SCEV *ScevIV2PVarP3 =
         SE.getSCEV(getInstructionByName(F, "iv2pvarp3"));
+    // 3 * (%var + {{0,+,1},+,1})
+    const SCEV *ScevIV2PVarM3 =
+        SE.getSCEV(getInstructionByName(F, "iv2pvarm3"));
+    // 3 * (%var + {{3,+,1},+,1})
+    const SCEV *ScevIV2PVarP3M3 =
+        SE.getSCEV(getInstructionByName(F, "iv2pvarp3m3"));
 
     auto diff = [&SE](const SCEV *LHS, const SCEV *RHS) -> std::optional<int> {
       auto ConstantDiffOrNone = computeConstantDifference(SE, LHS, RHS);
@@ -1204,6 +1212,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
     EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
     EXPECT_EQ(diff(ScevIV2P3, ScevIV2), 3);
     EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), -3);
+    EXPECT_EQ(diff(ScevIV2PVarP3M3, ScevIV2PVarM3), 9);
     EXPECT_EQ(diff(ScevV0, ScevIV), std::nullopt);
     EXPECT_EQ(diff(ScevIVNext, ScevV3), std::nullopt);
     EXPECT_EQ(diff(ScevYY, ScevV3), std::nullopt);

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Just in case it helps anyone else - I couldn't remember if it was always safe to distribute a multiply over an add in two's complement, so here's the alive2 proof: https://alive2.llvm.org/ce/z/qcDLqq. The nonundef is required, but guaranteed in this case since C is a constant.

@nikic nikic merged commit 6da3361 into llvm:main Aug 14, 2024
8 of 10 checks passed
@nikic nikic deleted the scev-const-diff-3 branch August 14, 2024 07:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants