@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5706
5706
return true;
5707
5707
5708
5708
// Helper used to compare two SCEV expressions under the current predicates.
// Two expressions count as equal if they are the same SCEV object, or if the
// predicate set `Preds` implies an equality predicate between them in either
// operand order. Returns false only when all three checks fail.
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709
- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5710
- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5709
+ if (Expr1 != Expr2 &&
5710
+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711
+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5711
5712
return false;
5712
5713
return true;
5713
5714
};
@@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14857
14858
bool addOverflowAssumption(const SCEVPredicate *P) {
14858
14859
if (!NewPreds) {
14859
14860
// Check if we've already made this assumption.
14860
- return Pred && Pred->implies(P);
14861
+ return Pred && Pred->implies(P, SE );
14861
14862
}
14862
14863
NewPreds->push_back(P);
14863
14864
return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14938
14939
assert(LHS != RHS && "LHS and RHS are the same SCEV");
14939
14940
}
14940
14941
14941
- bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14942
+ bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14943
+ ScalarEvolution &SE) const {
14942
14944
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14943
14945
14944
14946
if (!Op)
@@ -14968,10 +14970,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14968
14970
14969
14971
// Accessor: returns the add-recurrence held in this predicate's AR member.
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14970
14972
14971
// Returns true if this wrap predicate implies predicate N.
//
// The removed (-) form accepted N only when it was a SCEVWrapPredicate on the
// same AddRec whose flags were covered by this predicate's flags. The added
// (+) form keeps that case and additionally accepts a different AddRec when:
//   * Flags is exactly IncrementNSSW or exactly IncrementNUSW,
//   * both step recurrences are known positive, and
//   * N's step and N's start are known ULE (NUSW) / SLE (NSSW) this
//     predicate's step and start, after bringing each pair to a common width.
// SE is used to obtain the steps, build the extensions and prove the
// comparisons. Any case not listed above yields false.
- bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14973
+ bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14974
+ ScalarEvolution &SE) const {
14972
14975
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14976
// Only another wrap predicate can be implied by a wrap predicate.
+ if (!Op)
14977
+ return false;
14978
+
14979
// Bail out unless combining N's flags with ours leaves ours unchanged, i.e.
// N asks for nothing extra (assuming setFlags unions its arguments).
+ if (setFlags(Flags, Op->Flags) != Flags)
14980
+ return false;
14981
+
14982
// Same AddRec and compatible flags: nothing further to prove.
+ if (Op->AR == AR)
14983
+ return true;
14984
+
14985
// The start/step reasoning below is attempted only when exactly one of the
// two increment flags is set.
+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14986
+ Flags != SCEVWrapPredicate::IncrementNUSW)
14987
+ return false;
14973
14988
14974
- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14989
// IsNUW is true for IncrementNUSW; it selects zero-extension and ULE below,
// otherwise sign-extension and SLE are used.
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14990
+ const SCEV *Step = AR->getStepRecurrence(SE);
14991
+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14992
+
14993
+ // If both steps are positive, this implies N if N's start and step are
14994
+ // ULE/SLE (for NUSW/NSSW) this predicate's start and step.
14995
+ if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
14996
+ const SCEV *OpStart = Op->AR->getStart();
14997
+ const SCEV *Start = AR->getStart();
14998
// Bring the two steps to a common width by extending the narrower one.
// NOTE(review): OpStep is always zero-extended here, while Step is
// sign-extended in the NSSW case; both are known positive at this point, so
// presumably the two kinds of extension agree -- confirm.
+ if (SE.getTypeSizeInBits(Step->getType()) >
14999
+ SE.getTypeSizeInBits(OpStep->getType())) {
15000
+ OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
15001
+ } else {
15002
+ Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
15003
+ : SE.getNoopOrSignExtend(Step, OpStep->getType());
15004
+ }
15005
// Same widening for the start values, with the extension kind chosen by
// IsNUW in both directions.
+ if (SE.getTypeSizeInBits(Start->getType()) >
15006
+ SE.getTypeSizeInBits(OpStart->getType())) {
15007
+ OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
15008
+ : SE.getSignExtendExpr(OpStart, Start->getType());
15009
+ } else {
15010
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
15011
+ : SE.getNoopOrSignExtend(Start, OpStart->getType());
15012
+ }
15013
+
15014
// Both comparisons must be provable: N's step <= our step and N's start <=
// our start, under the signedness selected by IsNUW.
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
15015
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
15016
+ SE.isKnownPredicate(Pred, OpStart, Start);
15017
+ }
15018
// Steps not both known positive: implication is not established.
+ return false;
14975
15019
}
14976
15020
14977
15021
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,48 +15059,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15015
15059
}
15016
15060
15017
15061
/// Union predicates don't get cached so create a dummy set ID for it.
/// Every predicate in \p Preds is inserted through add(), which receives
/// \p SE so that it can run its implication check.
- SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
15019
- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15062
+ SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15063
+ ScalarEvolution &SE)
15064
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15020
15065
for (const auto *P : Preds)
15021
- add(P);
15066
+ add(P, SE );
15022
15067
}
15023
15068
15024
15069
// Returns true when every member predicate reports isAlwaysTrue().
// (all_of over an empty range is true, so a union with no members is
// presumably treated as always true.)
bool SCEVUnionPredicate::isAlwaysTrue() const {
15025
15070
return all_of(Preds,
15026
15071
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15027
15072
}
15028
15073
15029
// Returns true if this union implies predicate N.
// If N is itself a union, every member of N must be implied by this union.
// Otherwise it is enough that at least one member of this union implies N.
// SE is forwarded unchanged to the member predicates' implies() calls.
- bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15074
+ bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15075
+ ScalarEvolution &SE) const {
15030
15076
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15031
- return all_of(Set->Preds,
15032
- [this](const SCEVPredicate *I) { return this->implies(I); });
15077
+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15078
+ return this->implies(I, SE);
15079
+ });
15033
15080
15034
15081
return any_of(Preds,
15035
- [N](const SCEVPredicate *I) { return I->implies(N); });
15082
+ [N, &SE ](const SCEVPredicate *I) { return I->implies(N, SE ); });
15036
15083
}
15037
15084
15038
15085
// Prints each member predicate to OS in stored order, passing Depth through
// unchanged to every member's print().
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15039
15086
for (const auto *Pred : Preds)
15040
15087
Pred->print(OS, Depth);
15041
15088
}
15042
15089
15043
// Adds predicate N to this union.
// A union argument is flattened: each of its members is added individually
// through a recursive call. A non-union N is appended only when the
// predicates collected so far do not already imply it; SE is passed to that
// implication check.
- void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15090
+ void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE ) {
15044
15091
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15045
15092
for (const auto *Pred : Set->Preds)
15046
- add(Pred);
15093
+ add(Pred, SE );
15047
15094
return;
15048
15095
}
15049
15096
15050
15097
// Only add predicate if it is not already implied by this union predicate.
15051
- if (!implies(N))
15098
+ if (!implies(N, SE ))
15052
15099
Preds.push_back(N);
15053
15100
}
15054
15101
15055
15102
// Constructs a PredicatedScalarEvolution bound to SE and loop L. The
// predicate set starts out as a union built from an empty predicate list.
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15056
15103
Loop &L)
15057
15104
: SE(SE), L(L) {
15058
15105
SmallVector<const SCEVPredicate*, 4> Empty;
15059
- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15106
+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE );
15060
15107
}
15061
15108
15062
15109
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15167,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15120
15167
}
15121
15168
15122
15169
// Adds Pred to the predicate set unless the current set already implies it.
// When it is not implied, a new union is built from the existing predicates
// plus Pred, it replaces the old one, and updateGeneration() is called.
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15123
- if (Preds->implies(&Pred))
15170
+ if (Preds->implies(&Pred, SE ))
15124
15171
return;
15125
15172
15126
15173
// Rebuild rather than mutate: copy the current predicates, append the new
// one, and construct a fresh union from the combined list.
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15127
15174
NewPreds.push_back(&Pred);
15128
- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15175
+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE );
15129
15176
updateGeneration();
15130
15177
}
15131
15178
@@ -15192,9 +15239,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15192
15239
15193
15240
// Copy constructor. Copies RewriteMap, SE, L, Generation and BackedgeCount
// from Init, builds a fresh union predicate from Init's predicate list, and
// then inserts every entry of Init.FlagsMap into this object's FlagsMap.
// NOTE(review): in the added (+) lines the `SE` handed to the union
// constructor is this object's own member, initialised earlier in the same
// initializer list; that relies on SE being declared before Preds in the
// class -- confirm against the class definition.
PredicatedScalarEvolution::PredicatedScalarEvolution(
15194
15241
const PredicatedScalarEvolution &Init)
15195
- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15196
- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15197
- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15242
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15243
+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15244
+ SE)),
15245
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15198
15246
for (auto I : Init.FlagsMap)
15199
15247
FlagsMap.insert(I);
15200
15248
}
0 commit comments