-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[Analysis] Add getPredicatedExitCount to ScalarEvolution #105649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
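@llvm/pr-subscribers-llvm-analysis

Author: David Sherwood (david-arm)

Changes

Due to a reviewer request on PR #88385 I have created this patch to add a getPredicatedExitCount function, which is similar to getExitCount except that it uses the predicated backedge taken information. With PR #88385 we will start to care about more loops with multiple exits, and want the ability to query exit counts for a particular exiting block. Such loops may require predicates in order to be vectorised.

The only way to test this patch is via unit tests that I have added to unittests/Analysis/ScalarEvolutionTest.cpp.

Full diff: https://github.com/llvm/llvm-project/pull/105649.diff

3 Files Affected: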
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 5154e2f6659c12..03fb11993448e5 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -871,6 +871,13 @@ class ScalarEvolution {
const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
ExitCountKind Kind = Exact);
+ /// Same as above except this uses the predicated backedge taken info and
+ /// may require predicates.
+ const SCEV *
+ getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
+ SmallVector<const SCEVPredicate *, 4> *Predicates,
+ ExitCountKind Kind = Exact);
+
/// If the specified loop has a predictable backedge-taken count, return it,
/// otherwise return a SCEVCouldNotCompute object. The backedge-taken count is
/// the number of times the loop header will be branched to from within the
@@ -1562,16 +1569,19 @@ class ScalarEvolution {
/// Return the number of times this loop exit may fall through to the back
/// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via
/// this block before this number of iterations, but may exit via another
- /// block.
- const SCEV *getExact(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ /// block. If \p Predicates is null the function returns CouldNotCompute if
+ /// predicates are required, otherwise it fills in the required predicates.
+ const SCEV *
+ getExact(const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
/// Get the constant max backedge taken count for the loop.
const SCEV *getConstantMax(ScalarEvolution *SE) const;
/// Get the constant max backedge taken count for the particular loop exit.
- const SCEV *getConstantMax(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ const SCEV *getConstantMax(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
/// Get the symbolic max backedge taken count for the loop.
const SCEV *
@@ -1579,8 +1589,9 @@ class ScalarEvolution {
SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr);
/// Get the symbolic max backedge taken count for the particular loop exit.
- const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ const SCEV *getSymbolicMax(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
/// Return true if the number of times this backedge is taken is either the
/// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a19358dee8ef49..3726ff323630ab 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8249,6 +8249,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
llvm_unreachable("Invalid ExitCountKind!");
}
+const SCEV *ScalarEvolution::getPredicatedExitCount(
+ const Loop *L, const BasicBlock *ExitingBlock,
+ SmallVector<const SCEVPredicate *, 4> *Predicates, ExitCountKind Kind) {
+ switch (Kind) {
+ case Exact:
+ return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
+ Predicates);
+ case SymbolicMaximum:
+ return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
+ Predicates);
+ case ConstantMaximum:
+ return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
+ Predicates);
+ };
+ llvm_unreachable("Invalid ExitCountKind!");
+}
+
const SCEV *
ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
SmallVector<const SCEVPredicate *, 4> &Preds) {
@@ -8578,30 +8595,53 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
}
/// Get the exact not taken count for this loop exit.
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const {
+const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ExactNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
+ if (ENT.hasAlwaysTruePredicate())
+ return ENT.ExactNotTaken;
+ else if (Predicates) {
+ for (const auto *P : ENT.Predicates)
+ Predicates->push_back(P);
+ return ENT.ExactNotTaken;
+ }
+ }
return SE->getCouldNotCompute();
}
const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
- const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ConstantMaxNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
+ if (ENT.hasAlwaysTruePredicate())
+ return ENT.ConstantMaxNotTaken;
+ else if (Predicates) {
+ for (const auto *P : ENT.Predicates)
+ Predicates->push_back(P);
+ return ENT.ConstantMaxNotTaken;
+ }
+ }
return SE->getCouldNotCompute();
}
const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
- const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.SymbolicMaxNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
+ if (ENT.hasAlwaysTruePredicate())
+ return ENT.SymbolicMaxNotTaken;
+ else if (Predicates) {
+ for (const auto *P : ENT.Predicates)
+ Predicates->push_back(P);
+ return ENT.SymbolicMaxNotTaken;
+ }
+ }
return SE->getCouldNotCompute();
}
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index d4d90d80f4cea1..a9bd4789707012 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1707,4 +1707,56 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) {
});
}
+TEST_F(ScalarEvolutionsTest, ExitCountWithPredicates) {
+ LLVMContext C;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M = parseAssemblyString(R"(
+define void @foo(ptr %dest, ptr %src, i64 noundef %end) {
+entry:
+ %cmp7 = icmp sgt i64 %end, 0
+ br i1 %cmp7, label %for.body, label %exit
+
+for.body:
+ %conv9 = phi i64 [ %conv, %for.body ], [ 0, %entry ]
+ %i.08 = phi i16 [ %inc, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i32, ptr %src, i64 %conv9
+ %0 = load i32, ptr %arrayidx, align 4
+ %arrayidx3 = getelementptr inbounds i32, ptr %dest, i64 %conv9
+ %1 = load i32, ptr %arrayidx3, align 4
+ %add = add i32 %1, %0
+ store i32 %add, ptr %arrayidx3, align 4
+ %inc = add i16 %i.08, 1
+ %conv = zext i16 %inc to i64
+ %cmp = icmp ult i64 %conv, %end
+ br i1 %cmp, label %for.body, label %exit
+
+exit:
+ ret void
+})",
+ Err, C);
+
+ ASSERT_TRUE(M && "Could not parse module?");
+ ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+ runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+ BasicBlock &EntryBB = F.getEntryBlock();
+ BasicBlock *ForBodyBB = nullptr;
+ Loop *Loop = nullptr;
+ for (BasicBlock *Succ : successors(&EntryBB)) {
+ Loop = LI.getLoopFor(Succ);
+ if (Loop) {
+ ForBodyBB = Loop->getHeader();
+ break;
+ }
+ }
+ ASSERT_TRUE(Loop && "Couldn't find the loop!");
+ ASSERT_TRUE(ForBodyBB && "Couldn't find the loop header!");
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ const SCEV *ExitCount = SE.getPredicatedExitCount(
+ Loop, ForBodyBB, &Predicates, ScalarEvolution::Exact);
+ ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
+ ASSERT_FALSE(Predicates.empty());
+ });
+}
+
} // end namespace llvm
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks ok to me, but not familiar with the context.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but please wait a bit in case @fhahn has comments.
```diff
+ ASSERT_TRUE(ForBodyBB && "Couldn't find the loop header!");
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ const SCEV *ExitCount = SE.getPredicatedExitCount(
+ Loop, ForBodyBB, &Predicates, ScalarEvolution::Exact);
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem to test the `SymbolicMaximum` and `ConstantMaximum` cases.
Not sure what @nikic's preference would be, but it might make sense to print the predicated exit counts like we print regular exit counts, if they differ from the regular exit counts, i.e.
```
Loop %header: <multiple exits> Unpredictable predicated backedge-taken count.
predicated exit count for header: Foo
predicated exit count for latch: Bar
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added additional unit tests for the symbolic and constant maximums, as well as printing out the predicated exit counts if we cannot compute the unpredicated ones.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please convert the unit tests into normal lit analysis tests, now that the printing allows it?
You can use `-scalar-evolution-classify-expressions=0` to suppress most of the superfluous output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I've removed the unit test and added a single LLVM IR test. I wrote the test to have multiple early exits to try and get greater coverage. Hope this works!
Due to a reviewer request on PR llvm#88385 I have created this patch to add a getPredicatedExitCount function, which is similar to getExitCount except that it uses the predicated backedge taken information. With PR llvm#88385 we will start to care about more loops with multiple exits, and want the ability to query exit counts for a particular exiting block. Such loops may require predicates in order to be vectorised. The only way to test this patch is via unit tests that I have added to unittests/Analysis/ScalarEvolutionTest.cpp.
* Print out the predicated exact and symbolic exit counts for blocks if the unpredicated exit count cannot be computed.
* Add unit tests for the symbolic and constant maximums.
7a83aeb
to
106a02e
Compare
Rebase |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks!
One more potential suggestion inline to try to reduce some duplication; I'm not sure it is clearly worth it, though.
```diff
  for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ConstantMaxNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be great to avoid duplicating this logic multiple times by passing a lambda that returns the desired count (or pass `ExitCountKind` and select based on that).
Not sure if that would overall be much simpler/desirable, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that's a good suggestion. I've tried to do this and it does seem neater.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
Due to a reviewer request on PR #88385 I have created this patch
to add a getPredicatedExitCount function, which is similar to
getExitCount except that it uses the predicated backedge taken
information. With PR #88385 we will start to care about more
loops with multiple exits, and want the ability to query exit
counts for a particular exiting block. Such loops may require
predicates in order to be vectorised.
The only way to test this patch is via unit tests that I have
added to unittests/Analysis/ScalarEvolutionTest.cpp.