Skip to content

Commit f002212

Browse files
committed
[SCEV] Use a SmallPtrSet as a temporary union predicate; NFC
Summary: Instead of creating and destroying SCEVUnionPredicate instances (which internally creates and destroys a DenseMap), use temporary SmallPtrSet instances to remember the set of predicates that will get reified into a SCEVUnionPredicate. Reviewers: silviu.baranga, sbaranga Subscribers: sanjoy, mcrosier, llvm-commits, mzolotukhin Differential Revision: https://reviews.llvm.org/D25000 llvm-svn: 282606
1 parent 3862365 commit f002212

File tree

2 files changed

+90
-63
lines changed

2 files changed

+90
-63
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -551,19 +551,36 @@ class ScalarEvolution {
551551
const SCEV *ExactNotTaken;
552552
const SCEV *MaxNotTaken;
553553

554-
/// A predicate union guard for this ExitLimit. The result is only
555-
/// valid if this predicate evaluates to 'true' at run-time.
556-
SCEVUnionPredicate Predicate;
554+
/// A set of predicate guards for this ExitLimit. The result is only valid
555+
/// if all of the predicates in \c Predicates evaluate to 'true' at
556+
/// run-time.
557+
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
558+
559+
void addPredicate(const SCEVPredicate *P) {
560+
assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
561+
Predicates.insert(P);
562+
}
557563

558564
/*implicit*/ ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) {}
559565

560-
ExitLimit(const SCEV *E, const SCEV *M, SCEVUnionPredicate &P)
561-
: ExactNotTaken(E), MaxNotTaken(M), Predicate(P) {
566+
ExitLimit(
567+
const SCEV *E, const SCEV *M,
568+
ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
569+
: ExactNotTaken(E), MaxNotTaken(M) {
562570
assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
563571
!isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
564572
"Exact is not allowed to be less precise than Max");
573+
for (auto *PredSet : PredSetList)
574+
for (auto *P : *PredSet)
575+
addPredicate(P);
565576
}
566577

578+
ExitLimit(const SCEV *E, const SCEV *M,
579+
const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
580+
: ExitLimit(E, M, {&PredSet}) {}
581+
582+
ExitLimit(const SCEV *E, const SCEV *M) : ExitLimit(E, M, None) {}
583+
567584
/// Test whether this ExitLimit contains any computed information, or
568585
/// whether it's all SCEVCouldNotCompute values.
569586
bool hasAnyInfo() const {
@@ -1581,9 +1598,9 @@ class ScalarEvolution {
15811598
SCEVUnionPredicate &A);
15821599
/// Tries to convert the \p S expression to an AddRec expression,
15831600
/// adding additional predicates to \p Preds as required.
1584-
const SCEVAddRecExpr *
1585-
convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
1586-
SCEVUnionPredicate &Preds);
1601+
const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates(
1602+
const SCEV *S, const Loop *L,
1603+
SmallPtrSetImpl<const SCEVPredicate *> &Preds);
15871604

15881605
private:
15891606
/// Compute the backedge taken count knowing the interval difference, the

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5656,11 +5656,14 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
56565656
[&](const EdgeExitInfo &EEI) {
56575657
BasicBlock *ExitBB = EEI.first;
56585658
const ExitLimit &EL = EEI.second;
5659-
if (EL.Predicate.isAlwaysTrue())
5659+
if (EL.Predicates.empty())
56605660
return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr);
5661-
return ExitNotTakenInfo(
5662-
ExitBB, EL.ExactNotTaken,
5663-
llvm::make_unique<SCEVUnionPredicate>(std::move(EL.Predicate)));
5661+
5662+
std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
5663+
for (auto *Pred : EL.Predicates)
5664+
Predicate->add(Pred);
5665+
5666+
return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate));
56645667
});
56655668
}
56665669

@@ -5691,7 +5694,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
56915694
BasicBlock *ExitBB = ExitingBlocks[i];
56925695
ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
56935696

5694-
assert((AllowPredicates || EL.Predicate.isAlwaysTrue()) &&
5697+
assert((AllowPredicates || EL.Predicates.empty()) &&
56955698
"Predicated exit limit when predicates are not allowed!");
56965699

56975700
// 1. For each exit that can be computed, add an entry to ExitCounts.
@@ -5861,9 +5864,6 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
58615864
BECount = EL0.ExactNotTaken;
58625865
}
58635866

5864-
SCEVUnionPredicate NP;
5865-
NP.add(&EL0.Predicate);
5866-
NP.add(&EL1.Predicate);
58675867
// There are cases (e.g. PR26207) where computeExitLimitFromCond is able
58685868
// to be more aggressive when computing BECount than when computing
58695869
// MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
@@ -5873,7 +5873,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
58735873
!isa<SCEVCouldNotCompute>(BECount))
58745874
MaxBECount = BECount;
58755875

5876-
return ExitLimit(BECount, MaxBECount, NP);
5876+
return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
58775877
}
58785878
if (BO->getOpcode() == Instruction::Or) {
58795879
// Recurse on the operands of the or.
@@ -5912,10 +5912,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
59125912
BECount = EL0.ExactNotTaken;
59135913
}
59145914

5915-
SCEVUnionPredicate NP;
5916-
NP.add(&EL0.Predicate);
5917-
NP.add(&EL1.Predicate);
5918-
return ExitLimit(BECount, MaxBECount, NP);
5915+
return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
59195916
}
59205917
}
59215918

@@ -6300,8 +6297,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
63006297
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
63016298
const SCEV *UpperBound =
63026299
getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
6303-
SCEVUnionPredicate P;
6304-
return ExitLimit(getCouldNotCompute(), UpperBound, P);
6300+
return ExitLimit(getCouldNotCompute(), UpperBound);
63056301
}
63066302

63076303
return getCouldNotCompute();
@@ -7062,7 +7058,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
70627058
// effectively V != 0. We know and take advantage of the fact that this
70637059
// expression only being used in a comparison by zero context.
70647060

7065-
SCEVUnionPredicate P;
7061+
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
70667062
// If the value is a constant
70677063
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
70687064
// If the value is already zero, the branch will execute zero times.
@@ -7075,7 +7071,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
70757071
// Try to make this an AddRec using runtime tests, in the first X
70767072
// iterations of this loop, where X is the SCEV expression found by the
70777073
// algorithm below.
7078-
AddRec = convertSCEVToAddRecWithPredicates(V, L, P);
7074+
AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
70797075

70807076
if (!AddRec || AddRec->getLoop() != L)
70817077
return getCouldNotCompute();
@@ -7097,7 +7093,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
70977093
// should not accept a root of 2.
70987094
const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
70997095
if (Val->isZero())
7100-
return ExitLimit(R1, R1, P); // We found a quadratic root!
7096+
return ExitLimit(R1, R1, Predicates); // We found a quadratic root!
71017097
}
71027098
}
71037099
return getCouldNotCompute();
@@ -7154,7 +7150,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
71547150
else
71557151
MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
71567152
: -CR.getUnsignedMin());
7157-
return ExitLimit(Distance, MaxBECount, P);
7153+
return ExitLimit(Distance, MaxBECount, Predicates);
71587154
}
71597155

71607156
// As a special case, handle the instance where Step is a positive power of
@@ -7209,7 +7205,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
72097205

72107206
const SCEV *Limit =
72117207
getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
7212-
return ExitLimit(Limit, Limit, P);
7208+
return ExitLimit(Limit, Limit, Predicates);
72137209
}
72147210
}
72157211

@@ -7222,14 +7218,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
72227218
loopHasNoAbnormalExits(AddRec->getLoop())) {
72237219
const SCEV *Exact =
72247220
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
7225-
return ExitLimit(Exact, Exact, P);
7221+
return ExitLimit(Exact, Exact, Predicates);
72267222
}
72277223

72287224
// Then, try to solve the above equation provided that Start is constant.
72297225
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
72307226
const SCEV *E = SolveLinEquationWithOverflow(
72317227
StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
7232-
return ExitLimit(E, E, P);
7228+
return ExitLimit(E, E, Predicates);
72337229
}
72347230
return getCouldNotCompute();
72357231
}
@@ -8634,7 +8630,7 @@ ScalarEvolution::ExitLimit
86348630
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
86358631
const Loop *L, bool IsSigned,
86368632
bool ControlsExit, bool AllowPredicates) {
8637-
SCEVUnionPredicate P;
8633+
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
86388634
// We handle only IV < Invariant
86398635
if (!isLoopInvariant(RHS, L))
86408636
return getCouldNotCompute();
@@ -8646,7 +8642,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
86468642
// Try to make this an AddRec using runtime tests, in the first X
86478643
// iterations of this loop, where X is the SCEV expression found by the
86488644
// algorithm below.
8649-
IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
8645+
IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
86508646
PredicatedIV = true;
86518647
}
86528648

@@ -8762,14 +8758,14 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
87628758
if (isa<SCEVCouldNotCompute>(MaxBECount))
87638759
MaxBECount = BECount;
87648760

8765-
return ExitLimit(BECount, MaxBECount, P);
8761+
return ExitLimit(BECount, MaxBECount, Predicates);
87668762
}
87678763

87688764
ScalarEvolution::ExitLimit
87698765
ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
87708766
const Loop *L, bool IsSigned,
87718767
bool ControlsExit, bool AllowPredicates) {
8772-
SCEVUnionPredicate P;
8768+
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
87738769
// We handle only IV > Invariant
87748770
if (!isLoopInvariant(RHS, L))
87758771
return getCouldNotCompute();
@@ -8779,7 +8775,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
87798775
// Try to make this an AddRec using runtime tests, in the first X
87808776
// iterations of this loop, where X is the SCEV expression found by the
87818777
// algorithm below.
8782-
IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
8778+
IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
87838779

87848780
// Avoid weird loops
87858781
if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8839,7 +8835,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
88398835
if (isa<SCEVCouldNotCompute>(MaxBECount))
88408836
MaxBECount = BECount;
88418837

8842-
return ExitLimit(BECount, MaxBECount, P);
8838+
return ExitLimit(BECount, MaxBECount, Predicates);
88438839
}
88448840

88458841
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
@@ -10161,25 +10157,34 @@ namespace {
1016110157

1016210158
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1016310159
public:
10164-
// Rewrites \p S in the context of a loop L and the predicate A.
10165-
// If Assume is true, rewrite is free to add further predicates to A
10166-
// such that the result will be an AddRecExpr.
10160+
/// Rewrites \p S in the context of a loop L and the SCEV predication
10161+
/// infrastructure.
10162+
///
10163+
/// If \p Pred is non-null, the SCEV expression is rewritten to respect the
10164+
/// equivalences present in \p Pred.
10165+
///
10166+
/// If \p NewPreds is non-null, rewrite is free to add further predicates to
10167+
/// \p NewPreds such that the result will be an AddRecExpr.
1016710168
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
10168-
SCEVUnionPredicate &A, bool Assume) {
10169-
SCEVPredicateRewriter Rewriter(L, SE, A, Assume);
10169+
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
10170+
SCEVUnionPredicate *Pred) {
10171+
SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
1017010172
return Rewriter.visit(S);
1017110173
}
1017210174

1017310175
SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
10174-
SCEVUnionPredicate &P, bool Assume)
10175-
: SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
10176+
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
10177+
SCEVUnionPredicate *Pred)
10178+
: SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
1017610179

1017710180
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
10178-
auto ExprPreds = P.getPredicatesForExpr(Expr);
10179-
for (auto *Pred : ExprPreds)
10180-
if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
10181-
if (IPred->getLHS() == Expr)
10182-
return IPred->getRHS();
10181+
if (Pred) {
10182+
auto ExprPreds = Pred->getPredicatesForExpr(Expr);
10183+
for (auto *Pred : ExprPreds)
10184+
if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
10185+
if (IPred->getLHS() == Expr)
10186+
return IPred->getRHS();
10187+
}
1018310188

1018410189
return Expr;
1018510190
}
@@ -10220,40 +10225,41 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1022010225
bool addOverflowAssumption(const SCEVAddRecExpr *AR,
1022110226
SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
1022210227
auto *A = SE.getWrapPredicate(AR, AddedFlags);
10223-
if (!Assume) {
10228+
if (!NewPreds) {
1022410229
// Check if we've already made this assumption.
10225-
if (P.implies(A))
10226-
return true;
10227-
return false;
10230+
return Pred && Pred->implies(A);
1022810231
}
10229-
P.add(A);
10232+
NewPreds->insert(A);
1023010233
return true;
1023110234
}
1023210235

10233-
SCEVUnionPredicate &P;
10236+
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
10237+
SCEVUnionPredicate *Pred;
1023410238
const Loop *L;
10235-
bool Assume;
1023610239
};
1023710240
} // end anonymous namespace
1023810241

1023910242
const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
1024010243
SCEVUnionPredicate &Preds) {
10241-
return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false);
10244+
return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
1024210245
}
1024310246

10244-
const SCEVAddRecExpr *
10245-
ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
10246-
SCEVUnionPredicate &Preds) {
10247-
SCEVUnionPredicate TransformPreds;
10248-
S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true);
10247+
const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
10248+
const SCEV *S, const Loop *L,
10249+
SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
10250+
10251+
SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
10252+
S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
1024910253
auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
1025010254

1025110255
if (!AddRec)
1025210256
return nullptr;
1025310257

1025410258
// Since the transformation was successful, we can now transfer the SCEV
1025510259
// predicates.
10256-
Preds.add(&TransformPreds);
10260+
for (auto *P : TransformPreds)
10261+
Preds.insert(P);
10262+
1025710263
return AddRec;
1025810264
}
1025910265

@@ -10480,11 +10486,15 @@ bool PredicatedScalarEvolution::hasNoOverflow(
1048010486

1048110487
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
1048210488
const SCEV *Expr = this->getSCEV(V);
10483-
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
10489+
SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
10490+
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
1048410491

1048510492
if (!New)
1048610493
return nullptr;
1048710494

10495+
for (auto *P : NewPreds)
10496+
Preds.add(P);
10497+
1048810498
updateGeneration();
1048910499
RewriteMap[SE.getSCEV(V)] = {Generation, New};
1049010500
return New;

0 commit comments

Comments
 (0)