Skip to content

Commit 121ecb0

Browse files
committed
[SCEV] Generalize MatchBinaryAddToConst to support non-add expressions.
This patch generalizes MatchBinaryAddToConst to support matching (A + C1), (A + C2), instead of just matching (A + C1), A. The existing cases can be handled by treating non-add expressions A as A + 0. Reviewed By: mkazantsev Differential Revision: https://reviews.llvm.org/D104634
1 parent 0c4651f commit 121ecb0

File tree

1 file changed

+47
-29
lines changed

1 file changed

+47
-29
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10075,23 +10075,48 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges(
1007510075
bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
1007610076
const SCEV *LHS,
1007710077
const SCEV *RHS) {
10078-
// Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
10079-
// Return Y via OutY.
10080-
auto MatchBinaryAddToConst =
10081-
[this](const SCEV *Result, const SCEV *X, APInt &OutY,
10082-
SCEV::NoWrapFlags ExpectedFlags) {
10083-
const SCEV *NonConstOp, *ConstOp;
10084-
SCEV::NoWrapFlags FlagsPresent;
10085-
10086-
if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
10087-
!isa<SCEVConstant>(ConstOp) || NonConstOp != X)
10078+
// Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
10079+
// C1 and C2 are constant integers. If either X or Y are not add expressions,
10080+
// consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
10081+
// OutC1 and OutC2.
10082+
auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
10083+
APInt &OutC1, APInt &OutC2,
10084+
SCEV::NoWrapFlags ExpectedFlags) {
10085+
const SCEV *XNonConstOp, *XConstOp;
10086+
const SCEV *YNonConstOp, *YConstOp;
10087+
SCEV::NoWrapFlags XFlagsPresent;
10088+
SCEV::NoWrapFlags YFlagsPresent;
10089+
10090+
if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
10091+
XConstOp = getZero(X->getType());
10092+
XNonConstOp = X;
10093+
XFlagsPresent = ExpectedFlags;
10094+
}
10095+
if (!isa<SCEVConstant>(XConstOp) ||
10096+
(XFlagsPresent & ExpectedFlags) != ExpectedFlags)
1008810097
return false;
1008910098

10090-
OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
10091-
return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
10099+
if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
10100+
YConstOp = getZero(Y->getType());
10101+
YNonConstOp = Y;
10102+
YFlagsPresent = ExpectedFlags;
10103+
}
10104+
10105+
if (!isa<SCEVConstant>(YConstOp) ||
10106+
(YFlagsPresent & ExpectedFlags) != ExpectedFlags)
10107+
return false;
10108+
10109+
if (YNonConstOp != XNonConstOp)
10110+
return false;
10111+
10112+
OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
10113+
OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
10114+
10115+
return true;
1009210116
};
1009310117

10094-
APInt C;
10118+
APInt C1;
10119+
APInt C2;
1009510120

1009610121
switch (Pred) {
1009710122
default:
@@ -10101,45 +10126,38 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
1010110126
std::swap(LHS, RHS);
1010210127
LLVM_FALLTHROUGH;
1010310128
case ICmpInst::ICMP_SLE:
10104-
// X s<= (X + C)<nsw> if C >= 0
10105-
if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
10129+
// (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
10130+
if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
1010610131
return true;
1010710132

10108-
// (X + C)<nsw> s<= X if C <= 0
10109-
if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
10110-
!C.isStrictlyPositive())
10111-
return true;
1011210133
break;
1011310134

1011410135
case ICmpInst::ICMP_SGT:
1011510136
std::swap(LHS, RHS);
1011610137
LLVM_FALLTHROUGH;
1011710138
case ICmpInst::ICMP_SLT:
10118-
// X s< (X + C)<nsw> if C > 0
10119-
if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
10120-
C.isStrictlyPositive())
10139+
// (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
10140+
if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
1012110141
return true;
1012210142

10123-
// (X + C)<nsw> s< X if C < 0
10124-
if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
10125-
return true;
1012610143
break;
1012710144

1012810145
case ICmpInst::ICMP_UGE:
1012910146
std::swap(LHS, RHS);
1013010147
LLVM_FALLTHROUGH;
1013110148
case ICmpInst::ICMP_ULE:
10132-
// X u<= (X + C)<nuw> for any C
10133-
if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW))
10149+
// (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
10150+
if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
1013410151
return true;
10152+
1013510153
break;
1013610154

1013710155
case ICmpInst::ICMP_UGT:
1013810156
std::swap(LHS, RHS);
1013910157
LLVM_FALLTHROUGH;
1014010158
case ICmpInst::ICMP_ULT:
10141-
// X u< (X + C)<nuw> if C != 0
10142-
if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue())
10159+
// (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
10160+
if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
1014310161
return true;
1014410162
break;
1014510163
}

0 commit comments

Comments
 (0)