Skip to content

[SCEV] Add predicated version of getSymbolicMaxBackedgeTakenCount. #93498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,13 @@ class ScalarEvolution {
return getBackedgeTakenCount(L, SymbolicMaximum);
}

/// Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of
/// SCEV predicates to Predicates that are required to be true in order for
/// the answer to be correct. Predicates can be checked with run-time
/// checks and can be used to perform loop versioning.
const SCEV *getPredicatedSymbolicMaxBackedgeTakenCount(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doc comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, thanks!

const Loop *L, SmallVector<const SCEVPredicate *, 4> &Predicates);

/// Return true if the backedge taken count is either the value returned by
/// getConstantMaxBackedgeTakenCount or zero.
bool isBackedgeTakenCountMaxOrZero(const Loop *L);
Expand Down Expand Up @@ -1549,7 +1556,9 @@ class ScalarEvolution {
ScalarEvolution *SE) const;

/// Get the symbolic max backedge taken count for the loop.
const SCEV *getSymbolicMax(const Loop *L, ScalarEvolution *SE);
const SCEV *
getSymbolicMax(const Loop *L, ScalarEvolution *SE,
SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr);

/// Get the symbolic max backedge taken count for the particular loop exit.
const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
Expand Down Expand Up @@ -1746,7 +1755,7 @@ class ScalarEvolution {

/// Similar to getBackedgeTakenInfo, but will add predicates as required
/// with the purpose of returning complete information.
const BackedgeTakenInfo &getPredicatedBackedgeTakenInfo(const Loop *L);
BackedgeTakenInfo &getPredicatedBackedgeTakenInfo(const Loop *L);

/// Compute the number of times the specified loop will iterate.
/// If AllowPredicates is set, we will create new SCEV predicates as
Expand Down Expand Up @@ -2311,6 +2320,9 @@ class PredicatedScalarEvolution {
/// Get the (predicated) backedge count for the analyzed loop.
const SCEV *getBackedgeTakenCount();

/// Get the (predicated) symbolic max backedge count for the analyzed loop.
const SCEV *getSymbolicMaxBackedgeTakenCount();

/// Adds a new predicate.
void addPredicate(const SCEVPredicate &Pred);

Expand Down Expand Up @@ -2379,6 +2391,9 @@ class PredicatedScalarEvolution {

/// The backedge taken count.
const SCEV *BackedgeCount = nullptr;

/// The symbolic backedge taken count.
const SCEV *SymbolicMaxBackedgeCount = nullptr;
};

template <> struct DenseMapInfo<ScalarEvolution::FoldID> {
Expand Down
48 changes: 44 additions & 4 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8295,6 +8295,11 @@ const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
llvm_unreachable("Invalid ExitCountKind!");
}

const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
const Loop *L, SmallVector<const SCEVPredicate *, 4> &Preds) {
return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
}

bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
}
Expand All @@ -8311,7 +8316,7 @@ static void PushLoopPHIs(const Loop *L,
Worklist.push_back(&PN);
}

const ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
auto &BTI = getBackedgeTakenInfo(L);
if (BTI.hasFullInfo())
Expand Down Expand Up @@ -8644,9 +8649,9 @@ ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
return getConstantMax();
}

const SCEV *
ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
ScalarEvolution *SE) {
const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps I'm wrong but this change looks to have taken inspiration (and a little code) from #88385. I'm happy for the collaboration but LLVM rules suggest it would have been polite to reference my PR or indicate co-authorship?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, this should have been referenced, apologies for that! To rectify that, how about reverting the patch and then recommitting with an adjusted message to include an attribution like

The extension of getSymbolicMax to support predication is inspired by David Sherwood's (@david-arm) version in https://github.com/llvm/llvm-project/pull/88385

Happy to adjust as needed, if you think a different wording would be better

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's ok, there's no need to revert the patch and no harm done! Perhaps just something to remember in future that's all. Thanks!

const Loop *L, ScalarEvolution *SE,
SmallVector<const SCEVPredicate *, 4> *Predicates) {
if (!SymbolicMax) {
// Form an expression for the maximum exit count possible for this loop. We
// merge the max and exact information to approximate a version of
Expand All @@ -8661,6 +8666,12 @@ ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
"We should only have known counts for exiting blocks that "
"dominate latch!");
ExitCounts.push_back(ExitCount);
if (Predicates)
for (const auto *P : ENT.Predicates)
Predicates->push_back(P);

assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
"Predicate should be always true!");
}
}
if (ExitCounts.empty())
Expand Down Expand Up @@ -13609,6 +13620,24 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
P->print(OS, 4);
}

Preds.clear();
auto *PredSymbolicMax =
SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
if (SymbolicBTC != PredSymbolicMax) {
OS << "Loop ";
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
OS << ": ";
if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
OS << "Predicated symbolic max backedge-taken count is ";
PrintSCEVWithTypeHint(OS, PredSymbolicMax);
} else
OS << "Unpredictable predicated symbolic max backedge-taken count.";
OS << "\n";
OS << " Predicates:\n";
for (const auto *P : Preds)
P->print(OS, 4);
}

if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
OS << "Loop ";
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
Expand Down Expand Up @@ -14822,6 +14851,17 @@ const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
return BackedgeCount;
}

const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
if (!SymbolicMaxBackedgeCount) {
SmallVector<const SCEVPredicate *, 4> Preds;
SymbolicMaxBackedgeCount =
SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
for (const auto *P : Preds)
addPredicate(*P);
}
return SymbolicMaxBackedgeCount;
}

void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
if (Preds->implies(&Pred))
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ define void @test1(i64 %x, ptr %a, ptr %b) {
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
;
entry:
br label %header
Expand Down Expand Up @@ -52,6 +55,9 @@ define void @test2(i64 %x, ptr %a) {
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
;
entry:
br label %header
Expand Down
Loading