@@ -10075,23 +10075,48 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges(
10075
10075
bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
10076
10076
const SCEV *LHS,
10077
10077
const SCEV *RHS) {
10078
- // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
10079
- // Return Y via OutY.
10080
- auto MatchBinaryAddToConst =
10081
- [this](const SCEV *Result, const SCEV *X, APInt &OutY,
10082
- SCEV::NoWrapFlags ExpectedFlags) {
10083
- const SCEV *NonConstOp, *ConstOp;
10084
- SCEV::NoWrapFlags FlagsPresent;
10085
-
10086
- if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
10087
- !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
10078
+ // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
10079
+ // C1 and C2 are constant integers. If either X or Y are not add expressions,
10080
+ // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
10081
+ // OutC1 and OutC2.
10082
+ auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
10083
+ APInt &OutC1, APInt &OutC2,
10084
+ SCEV::NoWrapFlags ExpectedFlags) {
10085
+ const SCEV *XNonConstOp, *XConstOp;
10086
+ const SCEV *YNonConstOp, *YConstOp;
10087
+ SCEV::NoWrapFlags XFlagsPresent;
10088
+ SCEV::NoWrapFlags YFlagsPresent;
10089
+
10090
+ if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
10091
+ XConstOp = getZero(X->getType());
10092
+ XNonConstOp = X;
10093
+ XFlagsPresent = ExpectedFlags;
10094
+ }
10095
+ if (!isa<SCEVConstant>(XConstOp) ||
10096
+ (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
10088
10097
return false;
10089
10098
10090
- OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
10091
- return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
10099
+ if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
10100
+ YConstOp = getZero(Y->getType());
10101
+ YNonConstOp = Y;
10102
+ YFlagsPresent = ExpectedFlags;
10103
+ }
10104
+
10105
+ if (!isa<SCEVConstant>(YConstOp) ||
10106
+ (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
10107
+ return false;
10108
+
10109
+ if (YNonConstOp != XNonConstOp)
10110
+ return false;
10111
+
10112
+ OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
10113
+ OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
10114
+
10115
+ return true;
10092
10116
};
10093
10117
10094
- APInt C;
10118
+ APInt C1;
10119
+ APInt C2;
10095
10120
10096
10121
switch (Pred) {
10097
10122
default:
@@ -10101,45 +10126,38 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
10101
10126
std::swap(LHS, RHS);
10102
10127
LLVM_FALLTHROUGH;
10103
10128
case ICmpInst::ICMP_SLE:
10104
- // X s<= (X + C )<nsw> if C >= 0
10105
- if (MatchBinaryAddToConst(RHS, LHS, C , SCEV::FlagNSW) && C.isNonNegative( ))
10129
+ // (X + C1)<nsw> s<= (X + C2 )<nsw> if C1 s<= C2.
10130
+ if (MatchBinaryAddToConst(LHS, RHS, C1, C2 , SCEV::FlagNSW) && C1.sle(C2 ))
10106
10131
return true;
10107
10132
10108
- // (X + C)<nsw> s<= X if C <= 0
10109
- if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
10110
- !C.isStrictlyPositive())
10111
- return true;
10112
10133
break;
10113
10134
10114
10135
case ICmpInst::ICMP_SGT:
10115
10136
std::swap(LHS, RHS);
10116
10137
LLVM_FALLTHROUGH;
10117
10138
case ICmpInst::ICMP_SLT:
10118
- // X s< (X + C)<nsw> if C > 0
10119
- if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
10120
- C.isStrictlyPositive())
10139
+ // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
10140
+ if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
10121
10141
return true;
10122
10142
10123
- // (X + C)<nsw> s< X if C < 0
10124
- if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
10125
- return true;
10126
10143
break;
10127
10144
10128
10145
case ICmpInst::ICMP_UGE:
10129
10146
std::swap(LHS, RHS);
10130
10147
LLVM_FALLTHROUGH;
10131
10148
case ICmpInst::ICMP_ULE:
10132
- // X u<= (X + C )<nuw> for any C
10133
- if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW))
10149
+ // (X + C1)<nuw> u<= (X + C2 )<nuw> for C1 u<= C2.
10150
+ if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2 ))
10134
10151
return true;
10152
+
10135
10153
break;
10136
10154
10137
10155
case ICmpInst::ICMP_UGT:
10138
10156
std::swap(LHS, RHS);
10139
10157
LLVM_FALLTHROUGH;
10140
10158
case ICmpInst::ICMP_ULT:
10141
- // X u< (X + C )<nuw> if C != 0
10142
- if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue( ))
10159
+ // (X + C1)<nuw> u< (X + C2 )<nuw> if C1 u< C2.
10160
+ if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2 ))
10143
10161
return true;
10144
10162
break;
10145
10163
}
0 commit comments