[SCEV] Support addrec in right hand side in howManyLessThans #92560

Merged: 9 commits, Jun 25, 2024
372 changes: 210 additions & 162 deletions llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13000,179 +13000,227 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
return RHS;
}

// When the RHS is not invariant, we do not know the end bound of the loop and
// cannot calculate the ExactBECount needed by ExitLimit. However, we can
// calculate the MaxBECount, given the start, stride and max value for the end
// bound of the loop (RHS), and the fact that IV does not overflow (which is
// checked above).
const SCEV *End = nullptr, *BECount = nullptr,
*BECountIfBackedgeTaken = nullptr;
if (!isLoopInvariant(RHS, L)) {
const SCEV *MaxBECount = computeMaxBECountForLT(
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
MaxBECount, false /*MaxOrZero*/, Predicates);
}

// We use the expression (max(End,Start)-Start)/Stride to describe the
// backedge count, as if the backedge is taken at least once max(End,Start)
// is End and so the result is as above, and if not max(End,Start) is Start
// so we get a backedge count of zero.
const SCEV *BECount = nullptr;
auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
// Can we prove (max(RHS,Start) > Start - Stride?
if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
// In this case, we can use a refined formula for computing backedge taken
// count. The general formula remains:
// "End-Start /uceiling Stride" where "End = max(RHS,Start)"
// We want to use the alternate formula:
// "((End - 1) - (Start - Stride)) /u Stride"
// Let's do a quick case analysis to show these are equivalent under
// our precondition that max(RHS,Start) > Start - Stride.
// * For RHS <= Start, the backedge-taken count must be zero.
// "((End - 1) - (Start - Stride)) /u Stride" reduces to
// "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
// "Stride - 1 /u Stride" which is indeed zero for all non-zero values
// of Stride. For 0 stride, we've use umin(1,Stride) above, reducing
// this to the stride of 1 case.
// * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
// "((End - 1) - (Start - Stride)) /u Stride" reduces to
// "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
// "((RHS - (Start - Stride) - 1) /u Stride".
// Our preconditions trivially imply no overflow in that form.
const SCEV *MinusOne = getMinusOne(Stride->getType());
const SCEV *Numerator =
getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
BECount = getUDivExpr(Numerator, Stride);
}

const SCEV *BECountIfBackedgeTaken = nullptr;
if (!BECount) {
auto canProveRHSGreaterThanEqualStart = [&]() {
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);

if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
return true;

// (RHS > Start - 1) implies RHS >= Start.
// * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
// "Start - 1" doesn't overflow.
// * For signed comparison, if Start - 1 does overflow, it's equal
// to INT_MAX, and "RHS >s INT_MAX" is trivially false.
// * For unsigned comparison, if Start - 1 does overflow, it's equal
// to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
//
// FIXME: Should isLoopEntryGuardedByCond do this for us?
auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
auto *StartMinusOne = getAddExpr(OrigStart,
getMinusOne(OrigStart->getType()));
return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
};

// If we know that RHS >= Start in the context of loop, then we know that
// max(RHS, Start) = RHS at this point.
const SCEV *End;
if (canProveRHSGreaterThanEqualStart()) {
End = RHS;
} else {
// If RHS < Start, the backedge will be taken zero times. So in
// general, we can write the backedge-taken count as:
const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
RHSAddRec->getNoWrapFlags()) {
// The structure of loop we are trying to calculate backedge count of:
//
// RHS >= Start ? ceil(RHS - Start) / Stride : 0
// left = left_start
// right = right_start
//
// We convert it to the following to make it more convenient for SCEV:
// while(left < right){
// ... do something here ...
// left += s1; // stride of left is s1 (s1 > 0)
// right += s2; // stride of right is s2 (s2 < 0)
// }
//
// ceil(max(RHS, Start) - Start) / Stride
End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);

// See what would happen if we assume the backedge is taken. This is
// used to compute MaxBECount.
BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
}
const SCEV *RHSStart = RHSAddRec->getStart();
const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);

// At this point, we know:
//
// 1. If IsSigned, Start <=s End; otherwise, Start <=u End
// 2. The index variable doesn't overflow.
//
// Therefore, we know N exists such that
// (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
// doesn't overflow.
//
// Using this information, try to prove whether the addition in
// "(Start - End) + (Stride - 1)" has unsigned overflow.
const SCEV *One = getOne(Stride->getType());
bool MayAddOverflow = [&] {
if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
if (StrideC->getAPInt().isPowerOf2()) {
// Suppose Stride is a power of two, and Start/End are unsigned
// integers. Let UMAX be the largest representable unsigned
// integer.
//
// By the preconditions of this function, we know
// "(Start + Stride * N) >= End", and this doesn't overflow.
// As a formula:
//
// End <= (Start + Stride * N) <= UMAX
//
// Subtracting Start from all the terms:
//
// End - Start <= Stride * N <= UMAX - Start
//
// Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
//
// End - Start <= Stride * N <= UMAX
//
// Stride * N is a multiple of Stride. Therefore,
//
// End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
//
// Since Stride is a power of two, UMAX + 1 is divisible by Stride.
// Therefore, UMAX mod Stride == Stride - 1. So we can write:
//
// End - Start <= Stride * N <= UMAX - Stride - 1
//
// Dropping the middle term:
//
// End - Start <= UMAX - Stride - 1
//
// Adding Stride - 1 to both sides:
//
// (End - Start) + (Stride - 1) <= UMAX
//
// In other words, the addition doesn't have unsigned overflow.
//
// A similar proof works if we treat Start/End as signed values.
// Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
// use signed max instead of unsigned max. Note that we're trying
// to prove a lack of unsigned overflow in either case.
return false;
// If Stride - RHSStride is positive and does not overflow, we can write
// backedge count as ->
// ceil((End - Start) /u (Stride - RHSStride))
// Where, End = max(RHSStart, Start)

// Check if RHSStride < 0 and Stride - RHSStride will not overflow.
if (isKnownNegative(RHSStride) &&
willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
RHSStride)) {

const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
if (isKnownPositive(Denominator)) {
End = IsSigned ? getSMaxExpr(RHSStart, Start)
: getUMaxExpr(RHSStart, Start);

// We can do this because End >= Start, as End = max(RHSStart, Start)
const SCEV *Delta = getMinusSCEV(End, Start);

BECount = getUDivCeilSCEV(Delta, Denominator);
BECountIfBackedgeTaken =
getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
}
}
if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
// If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
// If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
// If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
}
if (BECount == nullptr) {
// If we cannot calculate ExactBECount, we can calculate the MaxBECount,
// given the start, stride and max value for the end bound of the
// loop (RHS), and the fact that IV does not overflow (which is
// checked above).
const SCEV *MaxBECount = computeMaxBECountForLT(
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
MaxBECount, false /*MaxOrZero*/, Predicates);
}
} else {
// We use the expression (max(End,Start)-Start)/Stride to describe the
// backedge count, as if the backedge is taken at least once
// max(End,Start) is End and so the result is as above, and if not
// max(End,Start) is Start so we get a backedge count of zero.
auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
// Can we prove (max(RHS,Start) > Start - Stride?
if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
// In this case, we can use a refined formula for computing backedge
// taken count. The general formula remains:
// "End-Start /uceiling Stride" where "End = max(RHS,Start)"
// We want to use the alternate formula:
// "((End - 1) - (Start - Stride)) /u Stride"
// Let's do a quick case analysis to show these are equivalent under
// our precondition that max(RHS,Start) > Start - Stride.
// * For RHS <= Start, the backedge-taken count must be zero.
// "((End - 1) - (Start - Stride)) /u Stride" reduces to
// "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
// "Stride - 1 /u Stride" which is indeed zero for all non-zero values
// of Stride. For 0 stride, we've use umin(1,Stride) above,
// reducing this to the stride of 1 case.
// * For RHS >= Start, the backedge count must be "RHS-Start /uceil
// Stride".
// "((End - 1) - (Start - Stride)) /u Stride" reduces to
// "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
// "((RHS - (Start - Stride) - 1) /u Stride".
// Our preconditions trivially imply no overflow in that form.
const SCEV *MinusOne = getMinusOne(Stride->getType());
const SCEV *Numerator =
getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
BECount = getUDivExpr(Numerator, Stride);
}

if (!BECount) {
auto canProveRHSGreaterThanEqualStart = [&]() {
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);

if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
return true;

// (RHS > Start - 1) implies RHS >= Start.
// * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
// "Start - 1" doesn't overflow.
// * For signed comparison, if Start - 1 does overflow, it's equal
// to INT_MAX, and "RHS >s INT_MAX" is trivially false.
// * For unsigned comparison, if Start - 1 does overflow, it's equal
// to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
//
// If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
return false;
// FIXME: Should isLoopEntryGuardedByCond do this for us?
auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
auto *StartMinusOne =
getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
};

// If we know that RHS >= Start in the context of loop, then we know
// that max(RHS, Start) = RHS at this point.
if (canProveRHSGreaterThanEqualStart()) {
End = RHS;
} else {
// If RHS < Start, the backedge will be taken zero times. So in
// general, we can write the backedge-taken count as:
//
// RHS >= Start ? ceil(RHS - Start) / Stride : 0
//
// We convert it to the following to make it more convenient for SCEV:
//
// ceil(max(RHS, Start) - Start) / Stride
End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);

// See what would happen if we assume the backedge is taken. This is
// used to compute MaxBECount.
BECountIfBackedgeTaken =
getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
}
return true;
}();

const SCEV *Delta = getMinusSCEV(End, Start);
if (!MayAddOverflow) {
// floor((D + (S - 1)) / S)
// We prefer this formulation if it's legal because it's fewer operations.
BECount =
getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
} else {
BECount = getUDivCeilSCEV(Delta, Stride);
// At this point, we know:
//
// 1. If IsSigned, Start <=s End; otherwise, Start <=u End
// 2. The index variable doesn't overflow.
//
// Therefore, we know N exists such that
// (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
// doesn't overflow.
//
// Using this information, try to prove whether the addition in
// "(Start - End) + (Stride - 1)" has unsigned overflow.
const SCEV *One = getOne(Stride->getType());
bool MayAddOverflow = [&] {
if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
if (StrideC->getAPInt().isPowerOf2()) {
// Suppose Stride is a power of two, and Start/End are unsigned
// integers. Let UMAX be the largest representable unsigned
// integer.
//
// By the preconditions of this function, we know
// "(Start + Stride * N) >= End", and this doesn't overflow.
// As a formula:
//
// End <= (Start + Stride * N) <= UMAX
//
// Subtracting Start from all the terms:
//
// End - Start <= Stride * N <= UMAX - Start
//
// Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
//
// End - Start <= Stride * N <= UMAX
//
// Stride * N is a multiple of Stride. Therefore,
//
// End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
//
// Since Stride is a power of two, UMAX + 1 is divisible by
// Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
// write:
//
// End - Start <= Stride * N <= UMAX - Stride - 1
//
// Dropping the middle term:
//
// End - Start <= UMAX - Stride - 1
//
// Adding Stride - 1 to both sides:
//
// (End - Start) + (Stride - 1) <= UMAX
//
// In other words, the addition doesn't have unsigned overflow.
//
// A similar proof works if we treat Start/End as signed values.
// Just rewrite steps before "End - Start <= Stride * N <= UMAX"
// to use signed max instead of unsigned max. Note that we're
// trying to prove a lack of unsigned overflow in either case.
return false;
}
}
if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
// If Start is equal to Stride, (End - Start) + (Stride - 1) == End
// - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
// <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
// 1 <s End.
//
// If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
// End.
return false;
}
return true;
}();

const SCEV *Delta = getMinusSCEV(End, Start);
if (!MayAddOverflow) {
// floor((D + (S - 1)) / S)
// We prefer this formulation if it's legal because it's fewer
// operations.
BECount =
getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
} else {
BECount = getUDivCeilSCEV(Delta, Stride);
}
}
}
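
The new branch for an addrec RHS uses the formula ceil((End - Start) /u (Stride - RHSStride)) with End = max(RHSStart, Start). As a sanity check, the sketch below compares that closed form against a direct simulation of the `while (left < right)` loop from the comment. It assumes plain 64-bit integers instead of SCEV expressions, and `simulateBackedges` and `closedFormBackedges` are hypothetical helper names that are not part of LLVM.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Count backedges by running the loop from the patch comment directly:
//   while (left < right) { left += stride; right += rhsStride; }
std::uint64_t simulateBackedges(std::int64_t start, std::int64_t stride,
                                std::int64_t rhsStart, std::int64_t rhsStride) {
  std::uint64_t count = 0;
  std::int64_t left = start;
  std::int64_t right = rhsStart;
  while (left < right) {
    left += stride;
    right += rhsStride;
    ++count;
  }
  return count;
}

// Closed form mirroring the patch (signed compare, small values assumed):
//   ceil((max(rhsStart, start) - start) /u (stride - rhsStride))
std::uint64_t closedFormBackedges(std::int64_t start, std::int64_t stride,
                                  std::int64_t rhsStart, std::int64_t rhsStride) {
  // Same preconditions the patch checks: positive LHS stride, negative RHS stride.
  assert(stride > 0 && rhsStride < 0);
  std::int64_t end = std::max(rhsStart, start);
  // end >= start, so the unsigned difference is the true distance.
  std::uint64_t delta =
      static_cast<std::uint64_t>(end) - static_cast<std::uint64_t>(start);
  std::uint64_t denom = static_cast<std::uint64_t>(stride - rhsStride);
  // Ceiling division written so that no intermediate addition can wrap.
  return delta / denom + (delta % denom != 0 ? 1 : 0);
}
```

For example, with start = 0, stride = 2, rhsStart = 10, rhsStride = -1 the two values close their gap of 10 by 3 per iteration, so both functions return 4. When rhsStart is below start, the max clamps Delta to zero and the count is 0.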

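
The `MayAddOverflow` check decides between floor((D + (S - 1)) / S) and the slower `getUDivCeilSCEV` path. The sketch below illustrates why the check matters, using 8-bit unsigned arithmetic so the wraparound is easy to see. `ceilDivFast` and `ceilDivSafe` are hypothetical names, and the safe variant is only one possible overflow-free formulation, not necessarily the one LLVM emits.

```cpp
#include <cassert>
#include <cstdint>

// floor((delta + (stride - 1)) / stride) in 8-bit unsigned arithmetic.
// The addition wraps modulo 256 when delta + stride - 1 exceeds 255.
std::uint8_t ceilDivFast(std::uint8_t delta, std::uint8_t stride) {
  std::uint8_t sum = static_cast<std::uint8_t>(delta + (stride - 1));
  return static_cast<std::uint8_t>(sum / stride);
}

// Ceiling division that never adds to delta, so it cannot wrap:
//   (delta != 0) + floor((delta - (delta != 0)) / stride)
std::uint8_t ceilDivSafe(std::uint8_t delta, std::uint8_t stride) {
  std::uint8_t nonZero = delta != 0 ? 1 : 0;
  return static_cast<std::uint8_t>(nonZero + (delta - nonZero) / stride);
}
```

With delta = 250 and stride = 10, the fast form wraps (250 + 9 = 259, which is 3 modulo 256) and returns 0, while the safe form returns 25. With a power-of-two stride such as 8, the largest multiple of the stride that fits in 8 bits is 248, and 248 + 7 = 255 does not wrap, which is the situation the power-of-two proof in the comment describes.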