@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5706
5706
return true;
5707
5707
5708
5708
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709
- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5710
- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5709
+ if (Expr1 != Expr2 &&
5710
+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711
+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5711
5712
return false;
5712
5713
return true;
5713
5714
};
@@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14857
14858
bool addOverflowAssumption(const SCEVPredicate *P) {
14858
14859
if (!NewPreds) {
14859
14860
// Check if we've already made this assumption.
14860
- return Pred && Pred->implies(P);
14861
+ return Pred && Pred->implies(P, SE );
14861
14862
}
14862
14863
NewPreds->push_back(P);
14863
14864
return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14938
14939
assert(LHS != RHS && "LHS and RHS are the same SCEV");
14939
14940
}
14940
14941
14941
- bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14942
+ bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14943
+ ScalarEvolution &SE) const {
14942
14944
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14943
14945
14944
14946
if (!Op)
@@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14968
14970
14969
14971
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14970
14972
14971
- bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14973
+ bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14974
+ ScalarEvolution &SE) const {
14972
14975
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14976
+ if (!Op || setFlags(Flags, Op->Flags) != Flags)
14977
+ return false;
14978
+
14979
+ if (Op->AR == AR)
14980
+ return true;
14981
+
14982
+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14983
+ Flags != SCEVWrapPredicate::IncrementNUSW)
14984
+ return false;
14973
14985
14974
- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14986
+ const SCEV *Step = AR->getStepRecurrence(SE);
14987
+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14988
+ if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
14989
+ return false;
14990
+
14991
+ // If both steps are positive, this implies N, if N's start and step are
14992
+ // ULE/SLE (for NSUW/NSSW) than this'.
14993
+ Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
14994
+ Step = SE.getNoopOrZeroExtend(Step, WiderTy);
14995
+ OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
14996
+
14997
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14998
+ const SCEV *OpStart = Op->AR->getStart();
14999
+ const SCEV *Start = AR->getStart();
15000
+ OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15001
+ : SE.getNoopOrSignExtend(OpStart, WiderTy);
15002
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15003
+ : SE.getNoopOrSignExtend(Start, WiderTy);
15004
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
15005
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
15006
+ SE.isKnownPredicate(Pred, OpStart, Start);
14975
15007
}
14976
15008
14977
15009
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,48 +15047,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15015
15047
}
15016
15048
15017
15049
/// Union predicates don't get cached so create a dummy set ID for it.
15018
- SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
15019
- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15050
+ SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15051
+ ScalarEvolution &SE)
15052
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15020
15053
for (const auto *P : Preds)
15021
- add(P);
15054
+ add(P, SE );
15022
15055
}
15023
15056
15024
15057
bool SCEVUnionPredicate::isAlwaysTrue() const {
15025
15058
return all_of(Preds,
15026
15059
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15027
15060
}
15028
15061
15029
- bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15062
+ bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15063
+ ScalarEvolution &SE) const {
15030
15064
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15031
- return all_of(Set->Preds,
15032
- [this](const SCEVPredicate *I) { return this->implies(I); });
15065
+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15066
+ return this->implies(I, SE);
15067
+ });
15033
15068
15034
15069
return any_of(Preds,
15035
- [N](const SCEVPredicate *I) { return I->implies(N); });
15070
+ [N, &SE ](const SCEVPredicate *I) { return I->implies(N, SE ); });
15036
15071
}
15037
15072
15038
15073
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15039
15074
for (const auto *Pred : Preds)
15040
15075
Pred->print(OS, Depth);
15041
15076
}
15042
15077
15043
- void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15078
+ void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE ) {
15044
15079
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15045
15080
for (const auto *Pred : Set->Preds)
15046
- add(Pred);
15081
+ add(Pred, SE );
15047
15082
return;
15048
15083
}
15049
15084
15050
15085
// Only add predicate if it is not already implied by this union predicate.
15051
- if (!implies(N))
15086
+ if (!implies(N, SE ))
15052
15087
Preds.push_back(N);
15053
15088
}
15054
15089
15055
15090
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15056
15091
Loop &L)
15057
15092
: SE(SE), L(L) {
15058
15093
SmallVector<const SCEVPredicate*, 4> Empty;
15059
- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15094
+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE );
15060
15095
}
15061
15096
15062
15097
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15120
15155
}
15121
15156
15122
15157
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15123
- if (Preds->implies(&Pred))
15158
+ if (Preds->implies(&Pred, SE ))
15124
15159
return;
15125
15160
15126
15161
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15127
15162
NewPreds.push_back(&Pred);
15128
- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15163
+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE );
15129
15164
updateGeneration();
15130
15165
}
15131
15166
@@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15192
15227
15193
15228
PredicatedScalarEvolution::PredicatedScalarEvolution(
15194
15229
const PredicatedScalarEvolution &Init)
15195
- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15196
- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15197
- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15230
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15231
+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15232
+ SE)),
15233
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15198
15234
for (auto I : Init.FlagsMap)
15199
15235
FlagsMap.insert(I);
15200
15236
}
0 commit comments