Skip to content

Commit 60c1744

Browse files
mrdaybirdhiraditya
authored andcommitted
Update howManyLessThans
1 parent 2861ab8 commit 60c1744

File tree

1 file changed

+64
-16
lines changed

1 file changed

+64
-16
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13000,27 +13000,77 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
1300013000
return RHS;
1300113001
}
1300213002

13003+
const SCEV *End = nullptr, *BECount = nullptr,
13004+
*BECountIfBackedgeTaken = nullptr;
1300313005
if (!isLoopInvariant(RHS, L)) {
13004-
// If RHS is an add recurrence, try again with lhs=lhs-rhs and rhs=0
13005-
if(auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
13006-
return howManyLessThans(getMinusSCEV(IV, RHSAddRec),
13007-
getZero(IV->getType()), L, true, ControlsOnlyExit, AllowPredicates);
13008-
}
13009-
// If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13010-
// given the start, stride and max value for the end bound of the
13011-
// loop (RHS), and the fact that IV does not overflow (which is
13012-
// checked above).
13013-
const SCEV *MaxBECount = computeMaxBECountForLT(
13014-
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13015-
return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13016-
MaxBECount, false /*MaxOrZero*/, Predicates);
13006+
if (auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
13007+
/*
13008+
The structure of loop we are trying to calculate backedge-count of:
13009+
left = left_start
13010+
right = right_start
13011+
while(left < right){
13012+
// ... do something here ...
13013+
left += s1; // stride of left is s1>0
13014+
right -= s2; // stride of right is -s2 (s2 > 0)
13015+
}
13016+
// left and right are converging at the middle
13017+
// (maybe not exactly at center)
13018+
13019+
*/
13020+
const SCEV *RHSStart = RHSAddRec->getStart();
13021+
const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13022+
// if Stride-RHSStride>0 and does not overflow, we can write
13023+
// backedge count as:
13024+
// RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) ? 0
13025+
13026+
// check if Stride-RHSStride will not overflow
13027+
if (willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
13028+
const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13029+
if (isKnownPositive(Denominator)) {
13030+
End = IsSigned ? getSMaxExpr(RHSStart, Start) :
13031+
getUMaxExpr(RHSStart, Start); // max(RHSStart, Start)
13032+
13033+
const SCEV *Delta = getMinusSCEV(End, Start); // End >= Start
13034+
13035+
BECount = getUDivCeilSCEV(Delta, Denominator);
13036+
BECountIfBackedgeTaken = getUDivCeilSCEV(
13037+
getMinusSCEV(RHSStart, Start), Denominator);
13038+
13039+
const SCEV *ConstantMaxBECount;
13040+
bool MaxOrZero = false;
13041+
if (isa<SCEVConstant>(BECount)) {
13042+
ConstantMaxBECount = BECount;
13043+
} else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13044+
ConstantMaxBECount = BECountIfBackedgeTaken;
13045+
MaxOrZero = true;
13046+
} else {
13047+
ConstantMaxBECount = computeMaxBECountForLT(
13048+
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()),
13049+
IsSigned);
13050+
}
13051+
13052+
const SCEV *SymbolicMaxBECount = BECount;
13053+
return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount,
13054+
MaxOrZero, Predicates);
13055+
}
13056+
}
13057+
}
13058+
if (BECount == nullptr) {
13059+
// If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13060+
// given the start, stride and max value for the end bound of the
13061+
// loop (RHS), and the fact that IV does not overflow (which is
13062+
// checked above).
13063+
const SCEV *MaxBECount = computeMaxBECountForLT(
13064+
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13065+
return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13066+
MaxBECount, false /*MaxOrZero*/, Predicates);
13067+
}
1301713068
}
1301813069

1301913070
// We use the expression (max(End,Start)-Start)/Stride to describe the
1302013071
// backedge count, as if the backedge is taken at least once max(End,Start)
1302113072
// is End and so the result is as above, and if not max(End,Start) is Start
1302213073
// so we get a backedge count of zero.
13023-
const SCEV *BECount = nullptr;
1302413074
auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
1302513075
assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
1302613076
assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
@@ -13052,7 +13102,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
1305213102
BECount = getUDivExpr(Numerator, Stride);
1305313103
}
1305413104

13055-
const SCEV *BECountIfBackedgeTaken = nullptr;
1305613105
if (!BECount) {
1305713106
auto canProveRHSGreaterThanEqualStart = [&]() {
1305813107
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
@@ -13080,7 +13129,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
1308013129

1308113130
// If we know that RHS >= Start in the context of loop, then we know that
1308213131
// max(RHS, Start) = RHS at this point.
13083-
const SCEV *End;
1308413132
if (canProveRHSGreaterThanEqualStart()) {
1308513133
End = RHS;
1308613134
} else {

0 commit comments

Comments
 (0)