Skip to content

Commit 7583c48

Browse files
preames and nikic authored
[SCEV] Use power of two facts involving vscale when inferring wrap flags (#101380)
SCEV has logic for inferring wrap flags on AddRecs which are known to control an exit based on whether the step is a power of two. This logic only considered constants, and thus did not trigger for steps such as (4 x vscale) which are common in scalably vectorized loops. The net effect is that we were very sensitive to the preservation of nsw/nuw flags on such IVs, and could not infer trip counts if they got lost for any reason. --------- Co-authored-by: Nikita Popov <[email protected]>
1 parent 9effefb commit 7583c48

File tree

3 files changed

+90
-74
lines changed

3 files changed

+90
-74
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,9 @@ class ScalarEvolution {
10281028
/// Test if the given expression is known to be non-zero.
10291029
bool isKnownNonZero(const SCEV *S);
10301030

1031+
/// Test if the given expression is known to be a power of 2.
1032+
bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero = false);
1033+
10311034
/// Splits SCEV expression \p S into two SCEVs. One of them is obtained from
10321035
/// \p S by substitution of all AddRec sub-expression related to loop \p L
10331036
/// with initial value of that SCEV. The second is obtained from \p S by

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 71 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -9149,23 +9149,21 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
91499149
// behaviour), and we can prove the test sequence produced must repeat
91509150
// the same values on self-wrap of the IV, then we can infer that IV
91519151
// doesn't self wrap because if it did, we'd have an infinite (undefined)
9152-
// loop.
9152+
// loop. Note that a stride of 0 is trivially no-self-wrap by definition.
91539153
if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
91549154
// TODO: We can peel off any functions which are invertible *in L*. Loop
91559155
// invariant terms are effectively constants for our purposes here.
91569156
auto *InnerLHS = LHS;
91579157
if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
91589158
InnerLHS = ZExt->getOperand();
9159-
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
9160-
auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
9161-
if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9162-
StrideC && StrideC->getAPInt().isPowerOf2()) {
9163-
auto Flags = AR->getNoWrapFlags();
9164-
Flags = setFlags(Flags, SCEV::FlagNW);
9165-
SmallVector<const SCEV*> Operands{AR->operands()};
9166-
Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9167-
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9168-
}
9159+
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9160+
AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9161+
isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true)) {
9162+
auto Flags = AR->getNoWrapFlags();
9163+
Flags = setFlags(Flags, SCEV::FlagNW);
9164+
SmallVector<const SCEV *> Operands{AR->operands()};
9165+
Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9166+
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
91699167
}
91709168
}
91719169

@@ -10845,6 +10843,23 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
1084510843
return getUnsignedRangeMin(S) != 0;
1084610844
}
1084710845

10846+
bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero) {
10847+
auto NonRecursive = [this](const SCEV *S) {
10848+
if (auto *C = dyn_cast<SCEVConstant>(S))
10849+
return C->getAPInt().isPowerOf2();
10850+
// The vscale_range indicates vscale is a power-of-two.
10851+
return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
10852+
};
10853+
10854+
if (NonRecursive(S))
10855+
return true;
10856+
10857+
auto *Mul = dyn_cast<SCEVMulExpr>(S);
10858+
if (!Mul)
10859+
return false;
10860+
return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
10861+
}
10862+
1084810863
std::pair<const SCEV *, const SCEV *>
1084910864
ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
1085010865
// Compute SCEV on entry of loop L.
@@ -12775,8 +12790,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
1277512790
if (!isLoopInvariant(RHS, L))
1277612791
return false;
1277712792

12778-
auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
12779-
if (!StrideC || !StrideC->getAPInt().isPowerOf2())
12793+
if (!isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true))
1278012794
return false;
1278112795

1278212796
if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
@@ -13132,52 +13146,50 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
1313213146
// "(Start - End) + (Stride - 1)" has unsigned overflow.
1313313147
const SCEV *One = getOne(Stride->getType());
1313413148
bool MayAddOverflow = [&] {
13135-
if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
13136-
if (StrideC->getAPInt().isPowerOf2()) {
13137-
// Suppose Stride is a power of two, and Start/End are unsigned
13138-
// integers. Let UMAX be the largest representable unsigned
13139-
// integer.
13140-
//
13141-
// By the preconditions of this function, we know
13142-
// "(Start + Stride * N) >= End", and this doesn't overflow.
13143-
// As a formula:
13144-
//
13145-
// End <= (Start + Stride * N) <= UMAX
13146-
//
13147-
// Subtracting Start from all the terms:
13148-
//
13149-
// End - Start <= Stride * N <= UMAX - Start
13150-
//
13151-
// Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13152-
//
13153-
// End - Start <= Stride * N <= UMAX
13154-
//
13155-
// Stride * N is a multiple of Stride. Therefore,
13156-
//
13157-
// End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13158-
//
13159-
// Since Stride is a power of two, UMAX + 1 is divisible by
13160-
// Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13161-
// write:
13162-
//
13163-
// End - Start <= Stride * N <= UMAX - Stride - 1
13164-
//
13165-
// Dropping the middle term:
13166-
//
13167-
// End - Start <= UMAX - Stride - 1
13168-
//
13169-
// Adding Stride - 1 to both sides:
13170-
//
13171-
// (End - Start) + (Stride - 1) <= UMAX
13172-
//
13173-
// In other words, the addition doesn't have unsigned overflow.
13174-
//
13175-
// A similar proof works if we treat Start/End as signed values.
13176-
// Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13177-
// to use signed max instead of unsigned max. Note that we're
13178-
// trying to prove a lack of unsigned overflow in either case.
13179-
return false;
13180-
}
13149+
if (isKnownToBeAPowerOfTwo(Stride)) {
13150+
// Suppose Stride is a power of two, and Start/End are unsigned
13151+
// integers. Let UMAX be the largest representable unsigned
13152+
// integer.
13153+
//
13154+
// By the preconditions of this function, we know
13155+
// "(Start + Stride * N) >= End", and this doesn't overflow.
13156+
// As a formula:
13157+
//
13158+
// End <= (Start + Stride * N) <= UMAX
13159+
//
13160+
// Subtracting Start from all the terms:
13161+
//
13162+
// End - Start <= Stride * N <= UMAX - Start
13163+
//
13164+
// Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13165+
//
13166+
// End - Start <= Stride * N <= UMAX
13167+
//
13168+
// Stride * N is a multiple of Stride. Therefore,
13169+
//
13170+
// End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13171+
//
13172+
// Since Stride is a power of two, UMAX + 1 is divisible by
13173+
// Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13174+
// write:
13175+
//
13176+
// End - Start <= Stride * N <= UMAX - (Stride - 1)
13177+
//
13178+
// Dropping the middle term:
13179+
//
13180+
// End - Start <= UMAX - (Stride - 1)
13181+
//
13182+
// Adding Stride - 1 to both sides:
13183+
//
13184+
// (End - Start) + (Stride - 1) <= UMAX
13185+
//
13186+
// In other words, the addition doesn't have unsigned overflow.
13187+
//
13188+
// A similar proof works if we treat Start/End as signed values.
13189+
// Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13190+
// to use signed max instead of unsigned max. Note that we're
13191+
// trying to prove a lack of unsigned overflow in either case.
13192+
return false;
1318113193
}
1318213194
if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
1318313195
// If Start is equal to Stride, (End - Start) + (Stride - 1) == End

llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -364,25 +364,25 @@ for.end: ; preds = %for.body, %entry
364364
}
365365

366366
; The next two cases check to see if we can infer the flags on the IV
367-
; of a countup loop using vscale strides.
368-
; TODO: We should be able to because vscale is a power of two and these
369-
; are finite loops by assumption.
367+
; of a countup loop using vscale strides. vscale is a power of two
368+
; and these are finite loops by assumption.
370369

371370
define void @vscale_slt_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_range(2,1024) {
372371
; CHECK-LABEL: 'vscale_slt_noflags'
373372
; CHECK-NEXT: Classifying expressions for: @vscale_slt_noflags
374373
; CHECK-NEXT: %vscale = call i32 @llvm.vscale.i32()
375374
; CHECK-NEXT: --> vscale U: [2,1025) S: [2,1025)
376375
; CHECK-NEXT: %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
377-
; CHECK-NEXT: --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
376+
; CHECK-NEXT: --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: (vscale * ((-1 + %n) /u vscale))<nuw> LoopDispositions: { %for.body: Computable }
378377
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
379-
; CHECK-NEXT: --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
378+
; CHECK-NEXT: --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((4 * vscale * ((-1 + %n) /u vscale)) + %A) LoopDispositions: { %for.body: Computable }
380379
; CHECK-NEXT: %add = add i32 %i.05, %vscale
381-
; CHECK-NEXT: --> {vscale,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
380+
; CHECK-NEXT: --> {vscale,+,vscale}<nw><%for.body> U: full-set S: full-set Exits: (vscale * (1 + ((-1 + %n) /u vscale))<nuw>) LoopDispositions: { %for.body: Computable }
382381
; CHECK-NEXT: Determining loop execution counts for: @vscale_slt_noflags
383-
; CHECK-NEXT: Loop %for.body: Unpredictable backedge-taken count.
384-
; CHECK-NEXT: Loop %for.body: Unpredictable constant max backedge-taken count.
385-
; CHECK-NEXT: Loop %for.body: Unpredictable symbolic max backedge-taken count.
382+
; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-1 + %n) /u vscale)
383+
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 1073741822
384+
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u vscale)
385+
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
386386
;
387387
entry:
388388
%vscale = call i32 @llvm.vscale.i32()
@@ -411,15 +411,16 @@ define void @vscalex4_ult_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_
411411
; CHECK-NEXT: %VF = mul i32 %vscale, 4
412412
; CHECK-NEXT: --> (4 * vscale)<nuw><nsw> U: [8,4097) S: [8,4097)
413413
; CHECK-NEXT: %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
414-
; CHECK-NEXT: --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
414+
; CHECK-NEXT: --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (4 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) LoopDispositions: { %for.body: Computable }
415415
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
416-
; CHECK-NEXT: --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
416+
; CHECK-NEXT: --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((16 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) + %A) LoopDispositions: { %for.body: Computable }
417417
; CHECK-NEXT: %add = add i32 %i.05, %VF
418-
; CHECK-NEXT: --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
418+
; CHECK-NEXT: --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<nw><%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (vscale * (4 + (4 * ((-1 + %n) /u (4 * vscale)<nuw><nsw>))<nuw><nsw>)<nuw>) LoopDispositions: { %for.body: Computable }
419419
; CHECK-NEXT: Determining loop execution counts for: @vscalex4_ult_noflags
420-
; CHECK-NEXT: Loop %for.body: Unpredictable backedge-taken count.
421-
; CHECK-NEXT: Loop %for.body: Unpredictable constant max backedge-taken count.
422-
; CHECK-NEXT: Loop %for.body: Unpredictable symbolic max backedge-taken count.
420+
; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
421+
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 536870910
422+
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
423+
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
423424
;
424425
entry:
425426
%vscale = call i32 @llvm.vscale.i32()

0 commit comments

Comments
 (0)