Skip to content

Commit 0a357ad

Browse files
authored
[SCEV] Support non-constant step in howFarToZero (#94411)
VF * vscale is the canonical step for a scalably vectorized loop, and LFTR canonicalizes to NE loop tests, so having our trip count logic be unable to compute trip counts for such loops is unfortunate. The existing code needed minimal generalization to handle non-constant strides. The tricky cases to be sure we handle correctly are: zero, and -1 (due to the special case of abs(-1) being non-positive). This patch does the full generalization in terms of code structure, but in practice, this seems unlikely to benefit anything beyond the (C * vscale) case. I did some quick investigation, and it seems the context free non-zero, and sign checks are basically never disproved for arbitrary scales. I think we have alternate tactics available for these, but I'm going to return to that in a separate patch.
1 parent f10e71f commit 0a357ad

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10483,29 +10483,26 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1048310483
// Get the initial value for the loop.
1048410484
const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
1048510485
const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10486-
10487-
// For now we handle only constant steps.
10488-
//
10489-
// TODO: Handle a nonconstant Step given AddRec<NUW>. If the
10490-
// AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
10491-
// to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
10492-
// We have not yet seen any such cases.
1049310486
const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10494-
if (!StepC || StepC->getValue()->isZero())
10487+
10488+
if (!isLoopInvariant(Step, L) || !isKnownNonZero(Step))
1049510489
return getCouldNotCompute();
1049610490

1049710491
// For positive steps (counting up until unsigned overflow):
1049810492
// N = -Start/Step (as unsigned)
1049910493
// For negative steps (counting down to zero):
1050010494
// N = Start/-Step
1050110495
// First compute the unsigned distance from zero in the direction of Step.
10502-
bool CountDown = StepC->getAPInt().isNegative();
10503-
const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10496+
bool CountDown = isKnownNegative(Step);
10497+
if (!CountDown && !isKnownNonNegative(Step))
10498+
return getCouldNotCompute();
1050410499

10500+
const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
1050510501
// Handle unitary steps, which cannot wraparound.
1050610502
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
1050710503
// N = Distance (as unsigned)
10508-
if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
10504+
if (StepC &&
10505+
(StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
1050910506
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
1051010507
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
1051110508

@@ -10550,6 +10547,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1055010547
}
1055110548

1055210549
// Solve the general equation.
10550+
if (!StepC)
10551+
return getCouldNotCompute();
1055310552
const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
1055410553
getNegativeSCEV(Start), *this);
1055510554

llvm/test/Analysis/ScalarEvolution/scalable-vector.ll

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,14 @@ define void @vscale_step_ne_tripcount(i64 %N) vscale_range(2, 1024) {
9191
; CHECK-NEXT: %n.vec = sub i64 %n.rnd.up, %n.mod.vf
9292
; CHECK-NEXT: --> (4 * vscale * ((-1 + (4 * vscale)<nuw><nsw> + %N) /u (4 * vscale)<nuw><nsw>)) U: [0,-3) S: [-9223372036854775808,9223372036854775805)
9393
; CHECK-NEXT: %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
94-
; CHECK-NEXT: --> {0,+,(4 * vscale)<nuw><nsw>}<nuw><%vector.body> U: [0,-3) S: [-9223372036854775808,9223372036854775805) Exits: <<Unknown>> LoopDispositions: { %vector.body: Computable }
94+
; CHECK-NEXT: --> {0,+,(4 * vscale)<nuw><nsw>}<nuw><%vector.body> U: [0,-3) S: [-9223372036854775808,9223372036854775805) Exits: (4 * vscale * ((-1 * vscale * (4 + (-4 * ((-1 + (4 * vscale)<nuw><nsw> + %N) /u (4 * vscale)<nuw><nsw>))<nsw>)<nsw>) /u (4 * vscale)<nuw><nsw>)) LoopDispositions: { %vector.body: Computable }
9595
; CHECK-NEXT: %index.next = add nuw i64 %index, %2
96-
; CHECK-NEXT: --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<nuw><%vector.body> U: [8,-3) S: [-9223372036854775808,9223372036854775805) Exits: <<Unknown>> LoopDispositions: { %vector.body: Computable }
96+
; CHECK-NEXT: --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<nuw><%vector.body> U: [8,-3) S: [-9223372036854775808,9223372036854775805) Exits: (vscale * (4 + (4 * ((-1 * vscale * (4 + (-4 * ((-1 + (4 * vscale)<nuw><nsw> + %N) /u (4 * vscale)<nuw><nsw>))<nsw>)<nsw>) /u (4 * vscale)<nuw><nsw>))<nuw><nsw>)<nuw>) LoopDispositions: { %vector.body: Computable }
9797
; CHECK-NEXT: Determining loop execution counts for: @vscale_step_ne_tripcount
98-
; CHECK-NEXT: Loop %vector.body: Unpredictable backedge-taken count.
99-
; CHECK-NEXT: Loop %vector.body: Unpredictable constant max backedge-taken count.
100-
; CHECK-NEXT: Loop %vector.body: Unpredictable symbolic max backedge-taken count.
98+
; CHECK-NEXT: Loop %vector.body: backedge-taken count is ((-1 * vscale * (4 + (-4 * ((-1 + (4 * vscale)<nuw><nsw> + %N) /u (4 * vscale)<nuw><nsw>))<nsw>)<nsw>) /u (4 * vscale)<nuw><nsw>)
99+
; CHECK-NEXT: Loop %vector.body: constant max backedge-taken count is i64 2305843009213693951
100+
; CHECK-NEXT: Loop %vector.body: symbolic max backedge-taken count is ((-1 * vscale * (4 + (-4 * ((-1 + (4 * vscale)<nuw><nsw> + %N) /u (4 * vscale)<nuw><nsw>))<nsw>)<nsw>) /u (4 * vscale)<nuw><nsw>)
101+
; CHECK-NEXT: Loop %vector.body: Trip multiple is 1
101102
;
102103
entry:
103104
%0 = sub i64 -1, %N

0 commit comments

Comments
 (0)