@@ -13000,27 +13000,77 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13000
13000
return RHS;
13001
13001
}
13002
13002
13003
+ const SCEV *End = nullptr, *BECount = nullptr,
13004
+ *BECountIfBackedgeTaken = nullptr;
13003
13005
if (!isLoopInvariant(RHS, L)) {
13004
- // If RHS is an add recurrence, try again with lhs=lhs-rhs and rhs=0
13005
- if(auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
13006
- return howManyLessThans(getMinusSCEV(IV, RHSAddRec),
13007
- getZero(IV->getType()), L, true, ControlsOnlyExit, AllowPredicates);
13008
- }
13009
- // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13010
- // given the start, stride and max value for the end bound of the
13011
- // loop (RHS), and the fact that IV does not overflow (which is
13012
- // checked above).
13013
- const SCEV *MaxBECount = computeMaxBECountForLT(
13014
- Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13015
- return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13016
- MaxBECount, false /*MaxOrZero*/, Predicates);
13006
+ if (auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
13007
+ /*
13008
+ The structure of loop we are trying to calculate backedge-count of:
13009
+ left = left_start
13010
+ right = right_start
13011
+ while(left < right){
13012
+ // ... do something here ...
13013
+ left += s1; // stride of left is s1>0
13014
+ right -= s2; // stride of right is -s2 (s2 > 0)
13015
+ }
13016
+ // left and right are converging at the middle
13017
+ // (maybe not exactly at center)
13018
+
13019
+ */
13020
+ const SCEV *RHSStart = RHSAddRec->getStart();
13021
+ const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13022
+ // if Stride-RHSStride>0 and does not overflow, we can write
13023
+ // backedge count as:
13024
+ // RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) ? 0
13025
+
13026
+ // check if Stride-RHSStride will not overflow
13027
+ if (willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
13028
+ const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13029
+ if (isKnownPositive(Denominator)) {
13030
+ End = IsSigned ? getSMaxExpr(RHSStart, Start) :
13031
+ getUMaxExpr(RHSStart, Start); // max(RHSStart, Start)
13032
+
13033
+ const SCEV *Delta = getMinusSCEV(End, Start); // End >= Start
13034
+
13035
+ BECount = getUDivCeilSCEV(Delta, Denominator);
13036
+ BECountIfBackedgeTaken = getUDivCeilSCEV(
13037
+ getMinusSCEV(RHSStart, Start), Denominator);
13038
+
13039
+ const SCEV *ConstantMaxBECount;
13040
+ bool MaxOrZero = false;
13041
+ if (isa<SCEVConstant>(BECount)) {
13042
+ ConstantMaxBECount = BECount;
13043
+ } else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13044
+ ConstantMaxBECount = BECountIfBackedgeTaken;
13045
+ MaxOrZero = true;
13046
+ } else {
13047
+ ConstantMaxBECount = computeMaxBECountForLT(
13048
+ Start, Stride, RHS, getTypeSizeInBits(LHS->getType()),
13049
+ IsSigned);
13050
+ }
13051
+
13052
+ const SCEV *SymbolicMaxBECount = BECount;
13053
+ return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount,
13054
+ MaxOrZero, Predicates);
13055
+ }
13056
+ }
13057
+ }
13058
+ if (BECount == nullptr) {
13059
+ // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13060
+ // given the start, stride and max value for the end bound of the
13061
+ // loop (RHS), and the fact that IV does not overflow (which is
13062
+ // checked above).
13063
+ const SCEV *MaxBECount = computeMaxBECountForLT(
13064
+ Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13065
+ return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13066
+ MaxBECount, false /*MaxOrZero*/, Predicates);
13067
+ }
13017
13068
}
13018
13069
13019
13070
// We use the expression (max(End,Start)-Start)/Stride to describe the
13020
13071
// backedge count, as if the backedge is taken at least once max(End,Start)
13021
13072
// is End and so the result is as above, and if not max(End,Start) is Start
13022
13073
// so we get a backedge count of zero.
13023
- const SCEV *BECount = nullptr;
13024
13074
auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13025
13075
assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13026
13076
assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
@@ -13052,7 +13102,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13052
13102
BECount = getUDivExpr(Numerator, Stride);
13053
13103
}
13054
13104
13055
- const SCEV *BECountIfBackedgeTaken = nullptr;
13056
13105
if (!BECount) {
13057
13106
auto canProveRHSGreaterThanEqualStart = [&]() {
13058
13107
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
@@ -13080,7 +13129,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13080
13129
13081
13130
// If we know that RHS >= Start in the context of loop, then we know that
13082
13131
// max(RHS, Start) = RHS at this point.
13083
- const SCEV *End;
13084
13132
if (canProveRHSGreaterThanEqualStart()) {
13085
13133
End = RHS;
13086
13134
} else {
0 commit comments