@@ -10490,8 +10490,11 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10490
10490
if (!isLoopInvariant(Step, L))
10491
10491
return getCouldNotCompute();
10492
10492
10493
+ const auto &[RewriteMap, PreserveNUW, PreserveNSW] =
10494
+ collectRewriteInfoFromLoopGuards(L);
10493
10495
// Specialize step for this loop so we get context sensitive facts below.
10494
- const SCEV *StepWLG = applyLoopGuards(Step, L);
10496
+ const SCEV *StepWLG =
10497
+ applyLoopGuards(Step, L, RewriteMap, PreserveNUW, PreserveNSW);
10495
10498
10496
10499
// For positive steps (counting up until unsigned overflow):
10497
10500
// N = -Start/Step (as unsigned)
@@ -10508,7 +10511,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10508
10511
// N = Distance (as unsigned)
10509
10512
if (StepC &&
10510
10513
(StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
10511
- APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
10514
+ APInt MaxBECount = getUnsignedRangeMax(
10515
+ applyLoopGuards(Distance, L, RewriteMap, PreserveNUW, PreserveNSW));
10512
10516
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10513
10517
10514
10518
// When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10549,7 +10553,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10549
10553
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10550
10554
const SCEV *ConstantMax = getCouldNotCompute();
10551
10555
if (Exact != getCouldNotCompute()) {
10552
- APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
10556
+ APInt MaxInt = getUnsignedRangeMax(
10557
+ applyLoopGuards(Exact, L, RewriteMap, PreserveNUW, PreserveNSW));
10553
10558
ConstantMax =
10554
10559
getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10555
10560
}
@@ -10566,7 +10571,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10566
10571
10567
10572
const SCEV *M = E;
10568
10573
if (E != getCouldNotCompute()) {
10569
- APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
10574
+ APInt MaxWithGuards = getUnsignedRangeMax(
10575
+ applyLoopGuards(E, L, RewriteMap, PreserveNUW, PreserveNSW));
10570
10576
M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10571
10577
}
10572
10578
auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -15096,7 +15102,7 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15096
15102
15097
15103
public:
15098
15104
SCEVLoopGuardRewriter(ScalarEvolution &SE,
15099
- DenseMap<const SCEV *, const SCEV *> &M,
15105
+ const DenseMap<const SCEV *, const SCEV *> &M,
15100
15106
bool PreserveNUW, bool PreserveNSW)
15101
15107
: SCEVRewriteVisitor(SE), Map(M) {
15102
15108
if (PreserveNUW)
@@ -15191,7 +15197,8 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15191
15197
}
15192
15198
};
15193
15199
15194
- const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15200
+ std::tuple<DenseMap<const SCEV *, const SCEV *>, bool, bool>
15201
+ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
15195
15202
SmallVector<const SCEV *> ExprsToRewrite;
15196
15203
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15197
15204
const SCEV *RHS,
@@ -15600,9 +15607,6 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15600
15607
}
15601
15608
}
15602
15609
15603
- if (RewriteMap.empty())
15604
- return Expr;
15605
-
15606
15610
// Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15607
15611
// the replacement expressions are contained in the ranges of the replaced
15608
15612
// expressions.
@@ -15626,6 +15630,22 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15626
15630
RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
15627
15631
}
15628
15632
}
15633
+ return {RewriteMap, PreserveNUW, PreserveNSW};
15634
+ }
15635
+
15636
+ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15637
+ const auto &[RewriteMap, PreserveNUW, PreserveNSW] =
15638
+ collectRewriteInfoFromLoopGuards(L);
15639
+ return applyLoopGuards(Expr, L, RewriteMap, PreserveNUW, PreserveNSW);
15640
+ }
15641
+
15642
+ const SCEV *ScalarEvolution::applyLoopGuards(
15643
+ const SCEV *Expr, const Loop *L,
15644
+ const DenseMap<const SCEV *, const SCEV *> &RewriteMap, bool PreserveNUW,
15645
+ bool PreserveNSW) {
15646
+ if (RewriteMap.empty())
15647
+ return Expr;
15648
+
15629
15649
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
15630
15650
return Rewriter.visit(Expr);
15631
15651
}
0 commit comments