Skip to content

[Analysis] Add getPredicatedExitCount to ScalarEvolution #105649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 2, 2024

Conversation

david-arm
Copy link
Contributor

Due to a reviewer request on PR #88385 I have created this patch
to add a getPredicatedExitCount function, which is similar to
getExitCount except that it uses the predicated backedge taken
information. With PR #88385 we will start to care about more
loops with multiple exits, and want the ability to query exit
counts for a particular exiting block. Such loops may require
predicates in order to be vectorised.

The only way to test this patch is via unit tests that I have
added to unittests/Analysis/ScalarEvolutionTest.cpp.

@david-arm david-arm requested a review from nikic as a code owner August 22, 2024 12:39
@llvmbot llvmbot added the llvm:analysis Includes value tracking, cost tables and constant folding label Aug 22, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 22, 2024

@llvm/pr-subscribers-llvm-analysis

Author: David Sherwood (david-arm)

Changes

Due to a reviewer request on PR #88385 I have created this patch
to add a getPredicatedExitCount function, which is similar to
getExitCount except that it uses the predicated backedge taken
information. With PR #88385 we will start to care about more
loops with multiple exits, and want the ability to query exit
counts for a particular exiting block. Such loops may require
predicates in order to be vectorised.

The only way to test this patch is via unit tests that I have
added to unittests/Analysis/ScalarEvolutionTest.cpp.


Full diff: https://github.com/llvm/llvm-project/pull/105649.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (+18-7)
  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+51-11)
  • (modified) llvm/unittests/Analysis/ScalarEvolutionTest.cpp (+52)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 5154e2f6659c12..03fb11993448e5 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -871,6 +871,13 @@ class ScalarEvolution {
   const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
                            ExitCountKind Kind = Exact);
 
+  /// Same as above except this uses the predicated backedge taken info and
+  /// may require predicates.
+  const SCEV *
+  getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
+                         SmallVector<const SCEVPredicate *, 4> *Predicates,
+                         ExitCountKind Kind = Exact);
+
   /// If the specified loop has a predictable backedge-taken count, return it,
   /// otherwise return a SCEVCouldNotCompute object. The backedge-taken count is
   /// the number of times the loop header will be branched to from within the
@@ -1562,16 +1569,19 @@ class ScalarEvolution {
     /// Return the number of times this loop exit may fall through to the back
     /// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via
     /// this block before this number of iterations, but may exit via another
-    /// block.
-    const SCEV *getExact(const BasicBlock *ExitingBlock,
-                         ScalarEvolution *SE) const;
+    /// block. If \p Predicates is null the function returns CouldNotCompute if
+    /// predicates are required, otherwise it fills in the required predicates.
+    const SCEV *
+    getExact(const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+             SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
 
     /// Get the constant max backedge taken count for the loop.
     const SCEV *getConstantMax(ScalarEvolution *SE) const;
 
     /// Get the constant max backedge taken count for the particular loop exit.
-    const SCEV *getConstantMax(const BasicBlock *ExitingBlock,
-                               ScalarEvolution *SE) const;
+    const SCEV *getConstantMax(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
 
     /// Get the symbolic max backedge taken count for the loop.
     const SCEV *
@@ -1579,8 +1589,9 @@ class ScalarEvolution {
                    SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr);
 
     /// Get the symbolic max backedge taken count for the particular loop exit.
-    const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
-                               ScalarEvolution *SE) const;
+    const SCEV *getSymbolicMax(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
 
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a19358dee8ef49..3726ff323630ab 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8249,6 +8249,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
   llvm_unreachable("Invalid ExitCountKind!");
 }
 
+const SCEV *ScalarEvolution::getPredicatedExitCount(
+    const Loop *L, const BasicBlock *ExitingBlock,
+    SmallVector<const SCEVPredicate *, 4> *Predicates, ExitCountKind Kind) {
+  switch (Kind) {
+  case Exact:
+    return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
+                                                      Predicates);
+  case SymbolicMaximum:
+    return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
+                                                            Predicates);
+  case ConstantMaximum:
+    return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
+                                                            Predicates);
+  };
+  llvm_unreachable("Invalid ExitCountKind!");
+}
+
 const SCEV *
 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
                                                  SmallVector<const SCEVPredicate *, 4> &Preds) {
@@ -8578,30 +8595,53 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
 }
 
 /// Get the exact not taken count for this loop exit.
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
-                                             ScalarEvolution *SE) const {
+const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.ExactNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.ExactNotTaken;
+      else if (Predicates) {
+        for (const auto *P : ENT.Predicates)
+          Predicates->push_back(P);
+        return ENT.ExactNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
 
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.ConstantMaxNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.ConstantMaxNotTaken;
+      else if (Predicates) {
+        for (const auto *P : ENT.Predicates)
+          Predicates->push_back(P);
+        return ENT.ConstantMaxNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
 
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.SymbolicMaxNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.SymbolicMaxNotTaken;
+      else if (Predicates) {
+        for (const auto *P : ENT.Predicates)
+          Predicates->push_back(P);
+        return ENT.SymbolicMaxNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index d4d90d80f4cea1..a9bd4789707012 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1707,4 +1707,56 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) {
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ExitCountWithPredicates) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(R"(
+define void @foo(ptr %dest, ptr %src, i64 noundef %end) {
+entry:
+  %cmp7 = icmp sgt i64 %end, 0
+  br i1 %cmp7, label %for.body, label %exit
+
+for.body:
+  %conv9 = phi i64 [ %conv, %for.body ], [ 0, %entry ]
+  %i.08 = phi i16 [ %inc, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %src, i64 %conv9
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, ptr %dest, i64 %conv9
+  %1 = load i32, ptr %arrayidx3, align 4
+  %add = add i32 %1, %0
+  store i32 %add, ptr %arrayidx3, align 4
+  %inc = add i16 %i.08, 1
+  %conv = zext i16 %inc to i64
+  %cmp = icmp ult i64 %conv, %end
+  br i1 %cmp, label %for.body, label %exit
+
+exit:
+  ret void
+})",
+                                                  Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    BasicBlock &EntryBB = F.getEntryBlock();
+    BasicBlock *ForBodyBB = nullptr;
+    Loop *Loop = nullptr;
+    for (BasicBlock *Succ : successors(&EntryBB)) {
+      Loop = LI.getLoopFor(Succ);
+      if (Loop) {
+        ForBodyBB = Loop->getHeader();
+        break;
+      }
+    }
+    ASSERT_TRUE(Loop && "Couldn't find the loop!");
+    ASSERT_TRUE(ForBodyBB && "Couldn't find the loop header!");
+    SmallVector<const SCEVPredicate *, 4> Predicates;
+    const SCEV *ExitCount = SE.getPredicatedExitCount(
+        Loop, ForBodyBB, &Predicates, ScalarEvolution::Exact);
+    ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
+    ASSERT_FALSE(Predicates.empty());
+  });
+}
+
 }  // end namespace llvm

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks ok to me, but not familiar with the context.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but please wait a bit in case @fhahn has comments.

ASSERT_TRUE(ForBodyBB && "Couldn't find the loop header!");
SmallVector<const SCEVPredicate *, 4> Predicates;
const SCEV *ExitCount = SE.getPredicatedExitCount(
Loop, ForBodyBB, &Predicates, ScalarEvolution::Exact);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem to test the symbolicmax ConstantMax cases.

Not sure what @nikic's preference would be, but it might make sense to print the predicated exit counts like we print regular exit counts, if they are different to the regular exit counts, i.e.

Loop %header: <multiple exits> Unpredictable predicated backedge-taken count.
    predicated exit count for header: Foo
    predicated exit count for latch: Bar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added additional unit tests for the symbolic and constant maximums, as well as printing out the predicated exit counts if we cannot compute the unpredicated ones.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please convert the unit tests into normal lit analysis tests, now that the printing allows it?

You can use -scalar-evolution-classify-expressions=0 to suppress most of the superfluous output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I've removed the unit test and added a single LLVM IR test. I wrote the test to have multiple early exits to try and get greater coverage. Hope this works!

Due to a reviewer request on PR llvm#88385 I have created this patch
to add a getPredicatedExitCount function, which is similar to
getExitCount except that it uses the predicated backedge taken
information. With PR llvm#88385 we will start to care about more
loops with multiple exits, and want the ability to query exit
counts for a particular exiting block. Such loops may require
predicates in order to be vectorised.

The only way to test this patch is via unit tests that I have
added to unittests/Analysis/ScalarEvolutionTest.cpp.
* Print out the predicated exact and symbolic exit counts for
blocks if the unpredicated exit count cannot be computed.
* Add unit tests for the symbolic and constant maximums.
@david-arm
Copy link
Contributor Author

Rebase

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks!

One more potential suggestion inline to try to reduce some duplication, not sure if it is clearly worth it though

for (const auto &ENT : ExitNotTaken)
if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
return ENT.ConstantMaxNotTaken;
if (ENT.ExitingBlock == ExitingBlock) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to avoid duplicating this logic multiple times by passing a lambda that returns the desired count (or pass ExitCountKind and select based on that).

Not sure if that would overall be much simpler/desirable though

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's a good suggestion. I've tried to do this and it does seem neater.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Collaborator

@huntergr-arm huntergr-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks.

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@david-arm david-arm merged commit df3d70b into llvm:main Sep 2, 2024
8 checks passed
@david-arm david-arm deleted the scev_exit_count branch October 3, 2024 08:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants