@@ -5656,11 +5656,14 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
5656
5656
[&](const EdgeExitInfo &EEI) {
5657
5657
BasicBlock *ExitBB = EEI.first ;
5658
5658
const ExitLimit &EL = EEI.second ;
5659
- if (EL.Predicate . isAlwaysTrue ())
5659
+ if (EL.Predicates . empty ())
5660
5660
return ExitNotTakenInfo (ExitBB, EL.ExactNotTaken , nullptr );
5661
- return ExitNotTakenInfo (
5662
- ExitBB, EL.ExactNotTaken ,
5663
- llvm::make_unique<SCEVUnionPredicate>(std::move (EL.Predicate )));
5661
+
5662
+ std::unique_ptr<SCEVUnionPredicate> Predicate (new SCEVUnionPredicate);
5663
+ for (auto *Pred : EL.Predicates )
5664
+ Predicate->add (Pred);
5665
+
5666
+ return ExitNotTakenInfo (ExitBB, EL.ExactNotTaken , std::move (Predicate));
5664
5667
});
5665
5668
}
5666
5669
@@ -5691,7 +5694,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
5691
5694
BasicBlock *ExitBB = ExitingBlocks[i];
5692
5695
ExitLimit EL = computeExitLimit (L, ExitBB, AllowPredicates);
5693
5696
5694
- assert ((AllowPredicates || EL.Predicate . isAlwaysTrue ()) &&
5697
+ assert ((AllowPredicates || EL.Predicates . empty ()) &&
5695
5698
" Predicated exit limit when predicates are not allowed!" );
5696
5699
5697
5700
// 1. For each exit that can be computed, add an entry to ExitCounts.
@@ -5861,9 +5864,6 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
5861
5864
BECount = EL0.ExactNotTaken ;
5862
5865
}
5863
5866
5864
- SCEVUnionPredicate NP;
5865
- NP.add (&EL0.Predicate );
5866
- NP.add (&EL1.Predicate );
5867
5867
// There are cases (e.g. PR26207) where computeExitLimitFromCond is able
5868
5868
// to be more aggressive when computing BECount than when computing
5869
5869
// MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
@@ -5873,7 +5873,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
5873
5873
!isa<SCEVCouldNotCompute>(BECount))
5874
5874
MaxBECount = BECount;
5875
5875
5876
- return ExitLimit (BECount, MaxBECount, NP );
5876
+ return ExitLimit (BECount, MaxBECount, {&EL0. Predicates , &EL1. Predicates } );
5877
5877
}
5878
5878
if (BO->getOpcode () == Instruction::Or) {
5879
5879
// Recurse on the operands of the or.
@@ -5912,10 +5912,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
5912
5912
BECount = EL0.ExactNotTaken ;
5913
5913
}
5914
5914
5915
- SCEVUnionPredicate NP;
5916
- NP.add (&EL0.Predicate );
5917
- NP.add (&EL1.Predicate );
5918
- return ExitLimit (BECount, MaxBECount, NP);
5915
+ return ExitLimit (BECount, MaxBECount, {&EL0.Predicates , &EL1.Predicates });
5919
5916
}
5920
5917
}
5921
5918
@@ -6300,8 +6297,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
6300
6297
unsigned BitWidth = getTypeSizeInBits (RHS->getType ());
6301
6298
const SCEV *UpperBound =
6302
6299
getConstant (getEffectiveSCEVType (RHS->getType ()), BitWidth);
6303
- SCEVUnionPredicate P;
6304
- return ExitLimit (getCouldNotCompute (), UpperBound, P);
6300
+ return ExitLimit (getCouldNotCompute (), UpperBound);
6305
6301
}
6306
6302
6307
6303
return getCouldNotCompute ();
@@ -7062,7 +7058,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
7062
7058
// effectively V != 0. We know and take advantage of the fact that this
7063
7059
// expression only being used in a comparison by zero context.
7064
7060
7065
- SCEVUnionPredicate P ;
7061
+ SmallPtrSet< const SCEVPredicate *, 4 > Predicates ;
7066
7062
// If the value is a constant
7067
7063
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
7068
7064
// If the value is already zero, the branch will execute zero times.
@@ -7075,7 +7071,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
7075
7071
// Try to make this an AddRec using runtime tests, in the first X
7076
7072
// iterations of this loop, where X is the SCEV expression found by the
7077
7073
// algorithm below.
7078
- AddRec = convertSCEVToAddRecWithPredicates (V, L, P );
7074
+ AddRec = convertSCEVToAddRecWithPredicates (V, L, Predicates );
7079
7075
7080
7076
if (!AddRec || AddRec->getLoop () != L)
7081
7077
return getCouldNotCompute ();
@@ -7097,7 +7093,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
7097
7093
// should not accept a root of 2.
7098
7094
const SCEV *Val = AddRec->evaluateAtIteration (R1, *this );
7099
7095
if (Val->isZero ())
7100
- return ExitLimit (R1, R1, P ); // We found a quadratic root!
7096
+ return ExitLimit (R1, R1, Predicates ); // We found a quadratic root!
7101
7097
}
7102
7098
}
7103
7099
return getCouldNotCompute ();
@@ -7154,7 +7150,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
7154
7150
else
7155
7151
MaxBECount = getConstant (CountDown ? CR.getUnsignedMax ()
7156
7152
: -CR.getUnsignedMin ());
7157
- return ExitLimit (Distance, MaxBECount, P );
7153
+ return ExitLimit (Distance, MaxBECount, Predicates );
7158
7154
}
7159
7155
7160
7156
// As a special case, handle the instance where Step is a positive power of
@@ -7209,7 +7205,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
7209
7205
7210
7206
const SCEV *Limit =
7211
7207
getZeroExtendExpr (getTruncateExpr (ModuloResult, NarrowTy), WideTy);
7212
- return ExitLimit (Limit, Limit, P );
7208
+ return ExitLimit (Limit, Limit, Predicates );
7213
7209
}
7214
7210
}
7215
7211
@@ -7222,14 +7218,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
7222
7218
loopHasNoAbnormalExits (AddRec->getLoop ())) {
7223
7219
const SCEV *Exact =
7224
7220
getUDivExpr (Distance, CountDown ? getNegativeSCEV (Step) : Step);
7225
- return ExitLimit (Exact, Exact, P );
7221
+ return ExitLimit (Exact, Exact, Predicates );
7226
7222
}
7227
7223
7228
7224
// Then, try to solve the above equation provided that Start is constant.
7229
7225
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
7230
7226
const SCEV *E = SolveLinEquationWithOverflow (
7231
7227
StepC->getValue ()->getValue (), -StartC->getValue ()->getValue (), *this );
7232
- return ExitLimit (E, E, P );
7228
+ return ExitLimit (E, E, Predicates );
7233
7229
}
7234
7230
return getCouldNotCompute ();
7235
7231
}
@@ -8634,7 +8630,7 @@ ScalarEvolution::ExitLimit
8634
8630
ScalarEvolution::howManyLessThans (const SCEV *LHS, const SCEV *RHS,
8635
8631
const Loop *L, bool IsSigned,
8636
8632
bool ControlsExit, bool AllowPredicates) {
8637
- SCEVUnionPredicate P ;
8633
+ SmallPtrSet< const SCEVPredicate *, 4 > Predicates ;
8638
8634
// We handle only IV < Invariant
8639
8635
if (!isLoopInvariant (RHS, L))
8640
8636
return getCouldNotCompute ();
@@ -8646,7 +8642,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
8646
8642
// Try to make this an AddRec using runtime tests, in the first X
8647
8643
// iterations of this loop, where X is the SCEV expression found by the
8648
8644
// algorithm below.
8649
- IV = convertSCEVToAddRecWithPredicates (LHS, L, P );
8645
+ IV = convertSCEVToAddRecWithPredicates (LHS, L, Predicates );
8650
8646
PredicatedIV = true ;
8651
8647
}
8652
8648
@@ -8762,14 +8758,14 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
8762
8758
if (isa<SCEVCouldNotCompute>(MaxBECount))
8763
8759
MaxBECount = BECount;
8764
8760
8765
- return ExitLimit (BECount, MaxBECount, P );
8761
+ return ExitLimit (BECount, MaxBECount, Predicates );
8766
8762
}
8767
8763
8768
8764
ScalarEvolution::ExitLimit
8769
8765
ScalarEvolution::howManyGreaterThans (const SCEV *LHS, const SCEV *RHS,
8770
8766
const Loop *L, bool IsSigned,
8771
8767
bool ControlsExit, bool AllowPredicates) {
8772
- SCEVUnionPredicate P ;
8768
+ SmallPtrSet< const SCEVPredicate *, 4 > Predicates ;
8773
8769
// We handle only IV > Invariant
8774
8770
if (!isLoopInvariant (RHS, L))
8775
8771
return getCouldNotCompute ();
@@ -8779,7 +8775,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
8779
8775
// Try to make this an AddRec using runtime tests, in the first X
8780
8776
// iterations of this loop, where X is the SCEV expression found by the
8781
8777
// algorithm below.
8782
- IV = convertSCEVToAddRecWithPredicates (LHS, L, P );
8778
+ IV = convertSCEVToAddRecWithPredicates (LHS, L, Predicates );
8783
8779
8784
8780
// Avoid weird loops
8785
8781
if (!IV || IV->getLoop () != L || !IV->isAffine ())
@@ -8839,7 +8835,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
8839
8835
if (isa<SCEVCouldNotCompute>(MaxBECount))
8840
8836
MaxBECount = BECount;
8841
8837
8842
- return ExitLimit (BECount, MaxBECount, P );
8838
+ return ExitLimit (BECount, MaxBECount, Predicates );
8843
8839
}
8844
8840
8845
8841
const SCEV *SCEVAddRecExpr::getNumIterationsInRange (const ConstantRange &Range,
@@ -10161,25 +10157,34 @@ namespace {
10161
10157
10162
10158
class SCEVPredicateRewriter : public SCEVRewriteVisitor <SCEVPredicateRewriter> {
10163
10159
public:
10164
- // Rewrites \p S in the context of a loop L and the predicate A.
10165
- // If Assume is true, rewrite is free to add further predicates to A
10166
- // such that the result will be an AddRecExpr.
10160
+ // / Rewrites \p S in the context of a loop L and the SCEV predication
10161
+ // / infrastructure.
10162
+ // /
10163
+ // / If \p Pred is non-null, the SCEV expression is rewritten to respect the
10164
+ // / equivalences present in \p Pred.
10165
+ // /
10166
+ // / If \p NewPreds is non-null, rewrite is free to add further predicates to
10167
+ // / \p NewPreds such that the result will be an AddRecExpr.
10167
10168
static const SCEV *rewrite (const SCEV *S, const Loop *L, ScalarEvolution &SE,
10168
- SCEVUnionPredicate &A, bool Assume) {
10169
- SCEVPredicateRewriter Rewriter (L, SE, A, Assume);
10169
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
10170
+ SCEVUnionPredicate *Pred) {
10171
+ SCEVPredicateRewriter Rewriter (L, SE, NewPreds, Pred);
10170
10172
return Rewriter.visit (S);
10171
10173
}
10172
10174
10173
10175
SCEVPredicateRewriter (const Loop *L, ScalarEvolution &SE,
10174
- SCEVUnionPredicate &P, bool Assume)
10175
- : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
10176
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
10177
+ SCEVUnionPredicate *Pred)
10178
+ : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
10176
10179
10177
10180
/// Rewrites a SCEVUnknown.
///
/// If a predicate was supplied and it holds an equality predicate whose
/// left-hand side is \p Expr, that predicate's right-hand side is returned.
/// In every other case \p Expr is returned unchanged.
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
  // Without a predicate there are no known equivalences to apply.
  if (!Pred)
    return Expr;

  auto Candidates = Pred->getPredicatesForExpr(Expr);
  for (auto *Candidate : Candidates) {
    const auto *Equal = dyn_cast<SCEVEqualPredicate>(Candidate);
    if (Equal && Equal->getLHS() == Expr)
      return Equal->getRHS();
  }

  return Expr;
}
@@ -10220,40 +10225,41 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
10220
10225
bool addOverflowAssumption (const SCEVAddRecExpr *AR,
10221
10226
SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
10222
10227
auto *A = SE.getWrapPredicate (AR, AddedFlags);
10223
- if (!Assume ) {
10228
+ if (!NewPreds ) {
10224
10229
// Check if we've already made this assumption.
10225
- if (P.implies (A))
10226
- return true ;
10227
- return false ;
10230
+ return Pred && Pred->implies (A);
10228
10231
}
10229
- P. add (A);
10232
+ NewPreds-> insert (A);
10230
10233
return true ;
10231
10234
}
10232
10235
10233
- SCEVUnionPredicate &P;
10236
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
10237
+ SCEVUnionPredicate *Pred;
10234
10238
const Loop *L;
10235
- bool Assume;
10236
10239
};
10237
10240
} // end anonymous namespace
10238
10241
10239
10242
/// Rewrites \p S for loop \p L so that it respects the equivalences present in
/// \p Preds. No set for new predicates is supplied, so the rewriter can only
/// rely on what \p Preds already contains.
const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
                                                   SCEVUnionPredicate &Preds) {
  SmallPtrSetImpl<const SCEVPredicate *> *const NoNewPreds = nullptr;
  return SCEVPredicateRewriter::rewrite(S, L, *this, NoNewPreds, &Preds);
}
10243
10246
10244
- const SCEVAddRecExpr *
10245
- ScalarEvolution::convertSCEVToAddRecWithPredicates (const SCEV *S, const Loop *L,
10246
- SCEVUnionPredicate &Preds) {
10247
- SCEVUnionPredicate TransformPreds;
10248
- S = SCEVPredicateRewriter::rewrite (S, L, *this , TransformPreds, true );
10247
+ const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates (
10248
+ const SCEV *S, const Loop *L,
10249
+ SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
10250
+
10251
+ SmallPtrSet<const SCEVPredicate *, 4 > TransformPreds;
10252
+ S = SCEVPredicateRewriter::rewrite (S, L, *this , &TransformPreds, nullptr );
10249
10253
auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
10250
10254
10251
10255
if (!AddRec)
10252
10256
return nullptr ;
10253
10257
10254
10258
// Since the transformation was successful, we can now transfer the SCEV
10255
10259
// predicates.
10256
- Preds.add (&TransformPreds);
10260
+ for (auto *P : TransformPreds)
10261
+ Preds.insert (P);
10262
+
10257
10263
return AddRec;
10258
10264
}
10259
10265
@@ -10480,11 +10486,15 @@ bool PredicatedScalarEvolution::hasNoOverflow(
10480
10486
10481
10487
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec (Value *V) {
10482
10488
const SCEV *Expr = this ->getSCEV (V);
10483
- auto *New = SE.convertSCEVToAddRecWithPredicates (Expr, &L, Preds);
10489
+ SmallPtrSet<const SCEVPredicate *, 4 > NewPreds;
10490
+ auto *New = SE.convertSCEVToAddRecWithPredicates (Expr, &L, NewPreds);
10484
10491
10485
10492
if (!New)
10486
10493
return nullptr ;
10487
10494
10495
+ for (auto *P : NewPreds)
10496
+ Preds.add (P);
10497
+
10488
10498
updateGeneration ();
10489
10499
RewriteMap[SE.getSCEV (V)] = {Generation, New};
10490
10500
return New;
0 commit comments