Skip to content

Commit 7e59b20

Browse files
authored
[SCEV] Support addrec in right hand side in howManyLessThans (#92560)
Fixes #92554 (std::reverse will auto-vectorize now).

When calculating the number of times an exit condition containing a comparison is executed, we mostly assume that the RHS of the comparison is loop invariant, but it may be another add-recurrence.

~In that case, we can try the computation with `LHS = LHS - RHS` and `RHS = 0`.~ (This is not valid unless it is proven not to wrap.)

**Edit:** We can calculate the backedge count for a loop structured like:

```cpp
left = left_start
right = right_start
while(left < right){
  // ...do something...
  left += s1;  // the stride of left is s1 (> 0)
  right -= s2; // the stride of right is -s2 (s2 > 0)
}
// left and right converge somewhere in the middle of their start values
```

We can calculate the backedge count as ceil((End - left_start) /u (s1 - (-s2))), where End = max(left_start, right_start).

**Alive2**: https://alive2.llvm.org/ce/z/ggxx58
1 parent dff6871 commit 7e59b20

File tree

2 files changed

+346
-162
lines changed

2 files changed

+346
-162
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 210 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -13000,179 +13000,227 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
1300013000
return RHS;
1300113001
}
1300213002

13003-
// When the RHS is not invariant, we do not know the end bound of the loop and
13004-
// cannot calculate the ExactBECount needed by ExitLimit. However, we can
13005-
// calculate the MaxBECount, given the start, stride and max value for the end
13006-
// bound of the loop (RHS), and the fact that IV does not overflow (which is
13007-
// checked above).
13003+
const SCEV *End = nullptr, *BECount = nullptr,
13004+
*BECountIfBackedgeTaken = nullptr;
1300813005
if (!isLoopInvariant(RHS, L)) {
13009-
const SCEV *MaxBECount = computeMaxBECountForLT(
13010-
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13011-
return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13012-
MaxBECount, false /*MaxOrZero*/, Predicates);
13013-
}
13014-
13015-
// We use the expression (max(End,Start)-Start)/Stride to describe the
13016-
// backedge count, as if the backedge is taken at least once max(End,Start)
13017-
// is End and so the result is as above, and if not max(End,Start) is Start
13018-
// so we get a backedge count of zero.
13019-
const SCEV *BECount = nullptr;
13020-
auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13021-
assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13022-
assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13023-
assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13024-
// Can we prove (max(RHS,Start) > Start - Stride?
13025-
if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13026-
isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13027-
// In this case, we can use a refined formula for computing backedge taken
13028-
// count. The general formula remains:
13029-
// "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13030-
// We want to use the alternate formula:
13031-
// "((End - 1) - (Start - Stride)) /u Stride"
13032-
// Let's do a quick case analysis to show these are equivalent under
13033-
// our precondition that max(RHS,Start) > Start - Stride.
13034-
// * For RHS <= Start, the backedge-taken count must be zero.
13035-
// "((End - 1) - (Start - Stride)) /u Stride" reduces to
13036-
// "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13037-
// "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13038-
// of Stride. For 0 stride, we've use umin(1,Stride) above, reducing
13039-
// this to the stride of 1 case.
13040-
// * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
13041-
// "((End - 1) - (Start - Stride)) /u Stride" reduces to
13042-
// "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13043-
// "((RHS - (Start - Stride) - 1) /u Stride".
13044-
// Our preconditions trivially imply no overflow in that form.
13045-
const SCEV *MinusOne = getMinusOne(Stride->getType());
13046-
const SCEV *Numerator =
13047-
getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13048-
BECount = getUDivExpr(Numerator, Stride);
13049-
}
13050-
13051-
const SCEV *BECountIfBackedgeTaken = nullptr;
13052-
if (!BECount) {
13053-
auto canProveRHSGreaterThanEqualStart = [&]() {
13054-
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13055-
const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13056-
const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13057-
13058-
if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13059-
isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13060-
return true;
13061-
13062-
// (RHS > Start - 1) implies RHS >= Start.
13063-
// * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13064-
// "Start - 1" doesn't overflow.
13065-
// * For signed comparison, if Start - 1 does overflow, it's equal
13066-
// to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13067-
// * For unsigned comparison, if Start - 1 does overflow, it's equal
13068-
// to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13069-
//
13070-
// FIXME: Should isLoopEntryGuardedByCond do this for us?
13071-
auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13072-
auto *StartMinusOne = getAddExpr(OrigStart,
13073-
getMinusOne(OrigStart->getType()));
13074-
return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13075-
};
13076-
13077-
// If we know that RHS >= Start in the context of loop, then we know that
13078-
// max(RHS, Start) = RHS at this point.
13079-
const SCEV *End;
13080-
if (canProveRHSGreaterThanEqualStart()) {
13081-
End = RHS;
13082-
} else {
13083-
// If RHS < Start, the backedge will be taken zero times. So in
13084-
// general, we can write the backedge-taken count as:
13006+
const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13007+
if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13008+
RHSAddRec->getNoWrapFlags()) {
13009+
// The structure of the loop whose backedge count we are trying to compute:
1308513010
//
13086-
// RHS >= Start ? ceil(RHS - Start) / Stride : 0
13011+
// left = left_start
13012+
// right = right_start
1308713013
//
13088-
// We convert it to the following to make it more convenient for SCEV:
13014+
// while(left < right){
13015+
// ... do something here ...
13016+
// left += s1; // stride of left is s1 (s1 > 0)
13017+
// right += s2; // stride of right is s2 (s2 < 0)
13018+
// }
1308913019
//
13090-
// ceil(max(RHS, Start) - Start) / Stride
13091-
End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
1309213020

13093-
// See what would happen if we assume the backedge is taken. This is
13094-
// used to compute MaxBECount.
13095-
BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13096-
}
13021+
const SCEV *RHSStart = RHSAddRec->getStart();
13022+
const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
1309713023

13098-
// At this point, we know:
13099-
//
13100-
// 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13101-
// 2. The index variable doesn't overflow.
13102-
//
13103-
// Therefore, we know N exists such that
13104-
// (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13105-
// doesn't overflow.
13106-
//
13107-
// Using this information, try to prove whether the addition in
13108-
// "(Start - End) + (Stride - 1)" has unsigned overflow.
13109-
const SCEV *One = getOne(Stride->getType());
13110-
bool MayAddOverflow = [&] {
13111-
if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
13112-
if (StrideC->getAPInt().isPowerOf2()) {
13113-
// Suppose Stride is a power of two, and Start/End are unsigned
13114-
// integers. Let UMAX be the largest representable unsigned
13115-
// integer.
13116-
//
13117-
// By the preconditions of this function, we know
13118-
// "(Start + Stride * N) >= End", and this doesn't overflow.
13119-
// As a formula:
13120-
//
13121-
// End <= (Start + Stride * N) <= UMAX
13122-
//
13123-
// Subtracting Start from all the terms:
13124-
//
13125-
// End - Start <= Stride * N <= UMAX - Start
13126-
//
13127-
// Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13128-
//
13129-
// End - Start <= Stride * N <= UMAX
13130-
//
13131-
// Stride * N is a multiple of Stride. Therefore,
13132-
//
13133-
// End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13134-
//
13135-
// Since Stride is a power of two, UMAX + 1 is divisible by Stride.
13136-
// Therefore, UMAX mod Stride == Stride - 1. So we can write:
13137-
//
13138-
// End - Start <= Stride * N <= UMAX - Stride - 1
13139-
//
13140-
// Dropping the middle term:
13141-
//
13142-
// End - Start <= UMAX - Stride - 1
13143-
//
13144-
// Adding Stride - 1 to both sides:
13145-
//
13146-
// (End - Start) + (Stride - 1) <= UMAX
13147-
//
13148-
// In other words, the addition doesn't have unsigned overflow.
13149-
//
13150-
// A similar proof works if we treat Start/End as signed values.
13151-
// Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
13152-
// use signed max instead of unsigned max. Note that we're trying
13153-
// to prove a lack of unsigned overflow in either case.
13154-
return false;
13024+
// If Stride - RHSStride is positive and does not overflow, we can write
13025+
// backedge count as ->
13026+
// ceil((End - Start) /u (Stride - RHSStride))
13027+
// Where, End = max(RHSStart, Start)
13028+
13029+
// Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13030+
if (isKnownNegative(RHSStride) &&
13031+
willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13032+
RHSStride)) {
13033+
13034+
const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13035+
if (isKnownPositive(Denominator)) {
13036+
End = IsSigned ? getSMaxExpr(RHSStart, Start)
13037+
: getUMaxExpr(RHSStart, Start);
13038+
13039+
// We can do this because End >= Start, as End = max(RHSStart, Start)
13040+
const SCEV *Delta = getMinusSCEV(End, Start);
13041+
13042+
BECount = getUDivCeilSCEV(Delta, Denominator);
13043+
BECountIfBackedgeTaken =
13044+
getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
1315513045
}
1315613046
}
13157-
if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13158-
// If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
13159-
// If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
13160-
// If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
13047+
}
13048+
if (BECount == nullptr) {
13049+
// If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13050+
// given the start, stride and max value for the end bound of the
13051+
// loop (RHS), and the fact that IV does not overflow (which is
13052+
// checked above).
13053+
const SCEV *MaxBECount = computeMaxBECountForLT(
13054+
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13055+
return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13056+
MaxBECount, false /*MaxOrZero*/, Predicates);
13057+
}
13058+
} else {
13059+
// We use the expression (max(End,Start)-Start)/Stride to describe the
13060+
// backedge count, as if the backedge is taken at least once
13061+
// max(End,Start) is End and so the result is as above, and if not
13062+
// max(End,Start) is Start so we get a backedge count of zero.
13063+
auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13064+
assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13065+
assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13066+
assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13067+
// Can we prove max(RHS,Start) > Start - Stride?
13068+
if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13069+
isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13070+
// In this case, we can use a refined formula for computing backedge
13071+
// taken count. The general formula remains:
13072+
// "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13073+
// We want to use the alternate formula:
13074+
// "((End - 1) - (Start - Stride)) /u Stride"
13075+
// Let's do a quick case analysis to show these are equivalent under
13076+
// our precondition that max(RHS,Start) > Start - Stride.
13077+
// * For RHS <= Start, the backedge-taken count must be zero.
13078+
// "((End - 1) - (Start - Stride)) /u Stride" reduces to
13079+
// "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13080+
// "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13081+
// of Stride. For 0 stride, we've used umin(1,Stride) above,
13082+
// reducing this to the stride of 1 case.
13083+
// * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13084+
// Stride".
13085+
// "((End - 1) - (Start - Stride)) /u Stride" reduces to
13086+
// "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13087+
// "((RHS - (Start - Stride) - 1) /u Stride".
13088+
// Our preconditions trivially imply no overflow in that form.
13089+
const SCEV *MinusOne = getMinusOne(Stride->getType());
13090+
const SCEV *Numerator =
13091+
getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13092+
BECount = getUDivExpr(Numerator, Stride);
13093+
}
13094+
13095+
if (!BECount) {
13096+
auto canProveRHSGreaterThanEqualStart = [&]() {
13097+
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13098+
const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13099+
const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13100+
13101+
if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13102+
isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13103+
return true;
13104+
13105+
// (RHS > Start - 1) implies RHS >= Start.
13106+
// * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13107+
// "Start - 1" doesn't overflow.
13108+
// * For signed comparison, if Start - 1 does overflow, it's equal
13109+
// to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13110+
// * For unsigned comparison, if Start - 1 does overflow, it's equal
13111+
// to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
1316113112
//
13162-
// If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
13163-
return false;
13113+
// FIXME: Should isLoopEntryGuardedByCond do this for us?
13114+
auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13115+
auto *StartMinusOne =
13116+
getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13117+
return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13118+
};
13119+
13120+
// If we know that RHS >= Start in the context of loop, then we know
13121+
// that max(RHS, Start) = RHS at this point.
13122+
if (canProveRHSGreaterThanEqualStart()) {
13123+
End = RHS;
13124+
} else {
13125+
// If RHS < Start, the backedge will be taken zero times. So in
13126+
// general, we can write the backedge-taken count as:
13127+
//
13128+
// RHS >= Start ? ceil(RHS - Start) / Stride : 0
13129+
//
13130+
// We convert it to the following to make it more convenient for SCEV:
13131+
//
13132+
// ceil(max(RHS, Start) - Start) / Stride
13133+
End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13134+
13135+
// See what would happen if we assume the backedge is taken. This is
13136+
// used to compute MaxBECount.
13137+
BECountIfBackedgeTaken =
13138+
getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
1316413139
}
13165-
return true;
13166-
}();
1316713140

13168-
const SCEV *Delta = getMinusSCEV(End, Start);
13169-
if (!MayAddOverflow) {
13170-
// floor((D + (S - 1)) / S)
13171-
// We prefer this formulation if it's legal because it's fewer operations.
13172-
BECount =
13173-
getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13174-
} else {
13175-
BECount = getUDivCeilSCEV(Delta, Stride);
13141+
// At this point, we know:
13142+
//
13143+
// 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13144+
// 2. The index variable doesn't overflow.
13145+
//
13146+
// Therefore, we know N exists such that
13147+
// (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13148+
// doesn't overflow.
13149+
//
13150+
// Using this information, try to prove whether the addition in
13151+
// "(End - Start) + (Stride - 1)" has unsigned overflow.
13152+
const SCEV *One = getOne(Stride->getType());
13153+
bool MayAddOverflow = [&] {
13154+
if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
13155+
if (StrideC->getAPInt().isPowerOf2()) {
13156+
// Suppose Stride is a power of two, and Start/End are unsigned
13157+
// integers. Let UMAX be the largest representable unsigned
13158+
// integer.
13159+
//
13160+
// By the preconditions of this function, we know
13161+
// "(Start + Stride * N) >= End", and this doesn't overflow.
13162+
// As a formula:
13163+
//
13164+
// End <= (Start + Stride * N) <= UMAX
13165+
//
13166+
// Subtracting Start from all the terms:
13167+
//
13168+
// End - Start <= Stride * N <= UMAX - Start
13169+
//
13170+
// Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13171+
//
13172+
// End - Start <= Stride * N <= UMAX
13173+
//
13174+
// Stride * N is a multiple of Stride. Therefore,
13175+
//
13176+
// End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13177+
//
13178+
// Since Stride is a power of two, UMAX + 1 is divisible by
13179+
// Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13180+
// write:
13181+
//
13182+
// End - Start <= Stride * N <= UMAX - (Stride - 1)
13183+
//
13184+
// Dropping the middle term:
13185+
//
13186+
// End - Start <= UMAX - (Stride - 1)
13187+
//
13188+
// Adding Stride - 1 to both sides:
13189+
//
13190+
// (End - Start) + (Stride - 1) <= UMAX
13191+
//
13192+
// In other words, the addition doesn't have unsigned overflow.
13193+
//
13194+
// A similar proof works if we treat Start/End as signed values.
13195+
// Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13196+
// to use signed max instead of unsigned max. Note that we're
13197+
// trying to prove a lack of unsigned overflow in either case.
13198+
return false;
13199+
}
13200+
}
13201+
if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13202+
// If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13203+
// - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13204+
// <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13205+
// 1 <s End.
13206+
//
13207+
// If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13208+
// End.
13209+
return false;
13210+
}
13211+
return true;
13212+
}();
13213+
13214+
const SCEV *Delta = getMinusSCEV(End, Start);
13215+
if (!MayAddOverflow) {
13216+
// floor((D + (S - 1)) / S)
13217+
// We prefer this formulation if it's legal because it's fewer
13218+
// operations.
13219+
BECount =
13220+
getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13221+
} else {
13222+
BECount = getUDivCeilSCEV(Delta, Stride);
13223+
}
1317613224
}
1317713225
}
1317813226

0 commit comments

Comments
 (0)