@@ -5725,8 +5725,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5725
5725
return true;
5726
5726
5727
5727
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5728
- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5729
- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5728
+ if (Expr1 != Expr2 &&
5729
+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5730
+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5730
5731
return false;
5731
5732
return true;
5732
5733
};
@@ -14823,7 +14824,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14823
14824
bool addOverflowAssumption(const SCEVPredicate *P) {
14824
14825
if (!NewPreds) {
14825
14826
// Check if we've already made this assumption.
14826
- return Pred && Pred->implies(P);
14827
+ return Pred && Pred->implies(P, SE );
14827
14828
}
14828
14829
NewPreds->push_back(P);
14829
14830
return true;
@@ -14904,7 +14905,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14904
14905
assert(LHS != RHS && "LHS and RHS are the same SCEV");
14905
14906
}
14906
14907
14907
- bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14908
+ bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14909
+ ScalarEvolution &SE) const {
14908
14910
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14909
14911
14910
14912
if (!Op)
@@ -14934,10 +14936,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14934
14936
14935
14937
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14936
14938
14937
- bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14939
+ bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14940
+ ScalarEvolution &SE) const {
14938
14941
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14942
+ if (!Op)
14943
+ return false;
14944
+
14945
+ if (setFlags(Flags, Op->Flags) != Flags)
14946
+ return false;
14947
+
14948
+ if (Op->AR == AR)
14949
+ return true;
14950
+
14951
+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14952
+ Flags != SCEVWrapPredicate::IncrementNUSW)
14953
+ return false;
14939
14954
14940
- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14955
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14956
+ const SCEV *Step = AR->getStepRecurrence(SE);
14957
+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14958
+
14959
+ // If both steps are positive, this implies N, if N's start and step are
14960
+ // ULE/SLE (for NSUW/NSSW) than this'.
14961
+ if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
14962
+ const SCEV *OpStart = Op->AR->getStart();
14963
+ const SCEV *Start = AR->getStart();
14964
+ if (SE.getTypeSizeInBits(Step->getType()) >
14965
+ SE.getTypeSizeInBits(OpStep->getType())) {
14966
+ OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
14967
+ } else {
14968
+ Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
14969
+ : SE.getNoopOrSignExtend(Step, OpStep->getType());
14970
+ }
14971
+ if (SE.getTypeSizeInBits(Start->getType()) >
14972
+ SE.getTypeSizeInBits(OpStart->getType())) {
14973
+ OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
14974
+ : SE.getSignExtendExpr(OpStart, Start->getType());
14975
+ } else {
14976
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
14977
+ : SE.getNoopOrSignExtend(Start, OpStart->getType());
14978
+ }
14979
+
14980
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
14981
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
14982
+ SE.isKnownPredicate(Pred, OpStart, Start);
14983
+ }
14984
+ return false;
14941
14985
}
14942
14986
14943
14987
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -14981,48 +15025,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
14981
15025
}
14982
15026
14983
15027
/// Union predicates don't get cached so create a dummy set ID for it.
14984
- SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14985
- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15028
+ SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15029
+ ScalarEvolution &SE)
15030
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14986
15031
for (const auto *P : Preds)
14987
- add(P);
15032
+ add(P, SE );
14988
15033
}
14989
15034
14990
15035
bool SCEVUnionPredicate::isAlwaysTrue() const {
14991
15036
return all_of(Preds,
14992
15037
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14993
15038
}
14994
15039
14995
- bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15040
+ bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15041
+ ScalarEvolution &SE) const {
14996
15042
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14997
- return all_of(Set->Preds,
14998
- [this](const SCEVPredicate *I) { return this->implies(I); });
15043
+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15044
+ return this->implies(I, SE);
15045
+ });
14999
15046
15000
15047
return any_of(Preds,
15001
- [N](const SCEVPredicate *I) { return I->implies(N); });
15048
+ [N, &SE ](const SCEVPredicate *I) { return I->implies(N, SE ); });
15002
15049
}
15003
15050
15004
15051
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15005
15052
for (const auto *Pred : Preds)
15006
15053
Pred->print(OS, Depth);
15007
15054
}
15008
15055
15009
- void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15056
+ void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE ) {
15010
15057
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15011
15058
for (const auto *Pred : Set->Preds)
15012
- add(Pred);
15059
+ add(Pred, SE );
15013
15060
return;
15014
15061
}
15015
15062
15016
15063
// Only add predicate if it is not already implied by this union predicate.
15017
- if (!implies(N))
15064
+ if (!implies(N, SE ))
15018
15065
Preds.push_back(N);
15019
15066
}
15020
15067
15021
15068
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15022
15069
Loop &L)
15023
15070
: SE(SE), L(L) {
15024
15071
SmallVector<const SCEVPredicate*, 4> Empty;
15025
- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15072
+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE );
15026
15073
}
15027
15074
15028
15075
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15086,12 +15133,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15086
15133
}
15087
15134
15088
15135
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15089
- if (Preds->implies(&Pred))
15136
+ if (Preds->implies(&Pred, SE ))
15090
15137
return;
15091
15138
15092
15139
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15093
15140
NewPreds.push_back(&Pred);
15094
- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15141
+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE );
15095
15142
updateGeneration();
15096
15143
}
15097
15144
@@ -15158,9 +15205,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15158
15205
15159
15206
PredicatedScalarEvolution::PredicatedScalarEvolution(
15160
15207
const PredicatedScalarEvolution &Init)
15161
- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15162
- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15163
- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15208
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15209
+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15210
+ SE)),
15211
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15164
15212
for (auto I : Init.FlagsMap)
15165
15213
FlagsMap.insert(I);
15166
15214
}
0 commit comments