[SCEV] Use power of two facts involving vscale when inferring wrap flags #101380
SCEV has logic for inferring wrap flags on AddRecs which are known to control an exit based on whether the step is a power of two. This logic only considered constants, and thus did not trigger for steps such as (4 x vscale) which are common in scalably vectorized loops. The net effect is that we were very sensitive to the preservation of nsw/nuw flags on such IVs, and could not infer trip counts if the flags got lost for any reason.
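As a rough illustration of why the power-of-two property matters (a standalone sketch, not LLVM code; the helper names `isPowerOf2` and `landsOnStartAtFirstSelfWrap` are made up for this example): in an 8-bit model of an AddRec {Start,+,Step}, a power-of-two step divides 2^8, so the first time the recurrence has advanced by a full 2^8 it sits exactly on Start again and simply repeats. An IV that self-wrapped would therefore replay the same exit tests forever, which is presumably the intuition behind inferring no-self-wrap when the loop is known to be finite.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: true if V has exactly one bit set.
constexpr bool isPowerOf2(uint32_t V) { return V != 0 && (V & (V - 1)) == 0; }

// Toy 8-bit model of the recurrence {Start,+,Step}.
// Advances until the total distance travelled reaches at least 2^8 (the
// first self-wrap) and reports whether the value is exactly Start again.
bool landsOnStartAtFirstSelfWrap(uint8_t Start, uint8_t Step) {
  if (Step == 0)
    return true; // a zero step never leaves Start
  unsigned Travelled = 0;
  uint8_t V = Start;
  while (Travelled < 256) {
    V = static_cast<uint8_t>(V + Step); // wraps modulo 2^8
    Travelled += Step;
  }
  return V == Start;
}
```

For example, `landsOnStartAtFirstSelfWrap(5, 4)` holds, while with Start 0 and Step 3 the recurrence overshoots to 2, a value it had not produced before.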
@llvm/pr-subscribers-llvm-analysis
Author: Philip Reames (preames)
Full diff: https://github.com/llvm/llvm-project/pull/101380.diff
3 Files Affected:
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index d9bfca763819f..fbefa2bd074dd 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1028,6 +1028,9 @@ class ScalarEvolution {
/// Test if the given expression is known to be non-zero.
bool isKnownNonZero(const SCEV *S);
+ /// Test if the given expression is known to be a power of 2.
+ bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero = false);
+
/// Splits SCEV expression \p S into two SCEVs. One of them is obtained from
/// \p S by substitution of all AddRec sub-expression related to loop \p L
/// with initial value of that SCEV. The second is obtained from \p S by
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index fb56d5d436653..159aa6e93a6ad 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9156,16 +9156,14 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
auto *InnerLHS = LHS;
if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
InnerLHS = ZExt->getOperand();
- if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
- auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
- if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
- StrideC && StrideC->getAPInt().isPowerOf2()) {
- auto Flags = AR->getNoWrapFlags();
- Flags = setFlags(Flags, SCEV::FlagNW);
- SmallVector<const SCEV*> Operands{AR->operands()};
- Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
- setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
- }
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
+ AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
+ isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this))) {
+ auto Flags = AR->getNoWrapFlags();
+ Flags = setFlags(Flags, SCEV::FlagNW);
+ SmallVector<const SCEV*> Operands{AR->operands()};
+ Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
}
}
@@ -10845,6 +10843,25 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
return getUnsignedRangeMin(S) != 0;
}
+bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero) {
+ auto nonRecursive = [this](const SCEV *S) {
+ if (auto *C = dyn_cast<SCEVConstant>(S))
+ return C->getAPInt().isPowerOf2();
+ // The vscale_range indicates vscale is a power-of-two.
+ return S->getSCEVType() == scVScale && F.hasFnAttribute(Attribute::VScaleRange);;
+ };
+
+ if (nonRecursive(S))
+ return true;
+
+ auto *Mul = dyn_cast<SCEVMulExpr>(S);
+ if (!Mul || Mul->getNumOperands() != 2)
+ return false;
+ return nonRecursive(Mul->getOperand(0)) && nonRecursive(Mul->getOperand(1)) &&
+ (OrZero || isKnownNonZero(S));
+}
+
+
std::pair<const SCEV *, const SCEV *>
ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
// Compute SCEV on entry of loop L.
@@ -12775,8 +12792,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (!isLoopInvariant(RHS, L))
return false;
- auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
- if (!StrideC || !StrideC->getAPInt().isPowerOf2())
+ if (!isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this)))
return false;
if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
@@ -13132,52 +13148,50 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
// "(Start - End) + (Stride - 1)" has unsigned overflow.
const SCEV *One = getOne(Stride->getType());
bool MayAddOverflow = [&] {
- if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
- if (StrideC->getAPInt().isPowerOf2()) {
- // Suppose Stride is a power of two, and Start/End are unsigned
- // integers. Let UMAX be the largest representable unsigned
- // integer.
- //
- // By the preconditions of this function, we know
- // "(Start + Stride * N) >= End", and this doesn't overflow.
- // As a formula:
- //
- // End <= (Start + Stride * N) <= UMAX
- //
- // Subtracting Start from all the terms:
- //
- // End - Start <= Stride * N <= UMAX - Start
- //
- // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
- //
- // End - Start <= Stride * N <= UMAX
- //
- // Stride * N is a multiple of Stride. Therefore,
- //
- // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
- //
- // Since Stride is a power of two, UMAX + 1 is divisible by
- // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
- // write:
- //
- // End - Start <= Stride * N <= UMAX - Stride - 1
- //
- // Dropping the middle term:
- //
- // End - Start <= UMAX - Stride - 1
- //
- // Adding Stride - 1 to both sides:
- //
- // (End - Start) + (Stride - 1) <= UMAX
- //
- // In other words, the addition doesn't have unsigned overflow.
- //
- // A similar proof works if we treat Start/End as signed values.
- // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
- // to use signed max instead of unsigned max. Note that we're
- // trying to prove a lack of unsigned overflow in either case.
- return false;
- }
+ if (isKnownToBeAPowerOfTwo(Stride)) {
+ // Suppose Stride is a power of two, and Start/End are unsigned
+ // integers. Let UMAX be the largest representable unsigned
+ // integer.
+ //
+ // By the preconditions of this function, we know
+ // "(Start + Stride * N) >= End", and this doesn't overflow.
+ // As a formula:
+ //
+ // End <= (Start + Stride * N) <= UMAX
+ //
+ // Subtracting Start from all the terms:
+ //
+ // End - Start <= Stride * N <= UMAX - Start
+ //
+ // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
+ //
+ // End - Start <= Stride * N <= UMAX
+ //
+ // Stride * N is a multiple of Stride. Therefore,
+ //
+ // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
+ //
+ // Since Stride is a power of two, UMAX + 1 is divisible by
+ // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
+ // write:
+ //
+ // End - Start <= Stride * N <= UMAX - Stride - 1
+ //
+ // Dropping the middle term:
+ //
+ // End - Start <= UMAX - Stride - 1
+ //
+ // Adding Stride - 1 to both sides:
+ //
+ // (End - Start) + (Stride - 1) <= UMAX
+ //
+ // In other words, the addition doesn't have unsigned overflow.
+ //
+ // A similar proof works if we treat Start/End as signed values.
+ // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
+ // to use signed max instead of unsigned max. Note that we're
+ // trying to prove a lack of unsigned overflow in either case.
+ return false;
}
if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
// If Start is equal to Stride, (End - Start) + (Stride - 1) == End
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
index 943389d07eb8b..50e6014734f31 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
@@ -374,15 +374,16 @@ define void @vscale_slt_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_ra
; CHECK-NEXT: %vscale = call i32 @llvm.vscale.i32()
; CHECK-NEXT: --> vscale U: [2,1025) S: [2,1025)
; CHECK-NEXT: %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
-; CHECK-NEXT: --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: (vscale * ((-1 + %n) /u vscale))<nuw> LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
-; CHECK-NEXT: --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((4 * vscale * ((-1 + %n) /u vscale)) + %A) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: %add = add i32 %i.05, %vscale
-; CHECK-NEXT: --> {vscale,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {vscale,+,vscale}<nw><%for.body> U: full-set S: full-set Exits: (vscale * (1 + ((-1 + %n) /u vscale))<nuw>) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: Determining loop execution counts for: @vscale_slt_noflags
-; CHECK-NEXT: Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-1 + %n) /u vscale)
+; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 1073741822
+; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u vscale)
+; CHECK-NEXT: Loop %for.body: Trip multiple is 1
;
entry:
%vscale = call i32 @llvm.vscale.i32()
@@ -411,15 +412,16 @@ define void @vscalex4_ult_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_
; CHECK-NEXT: %VF = mul i32 %vscale, 4
; CHECK-NEXT: --> (4 * vscale)<nuw><nsw> U: [8,4097) S: [8,4097)
; CHECK-NEXT: %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
-; CHECK-NEXT: --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (4 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
-; CHECK-NEXT: --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((16 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) + %A) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: %add = add i32 %i.05, %VF
-; CHECK-NEXT: --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<nw><%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (vscale * (4 + (4 * ((-1 + %n) /u (4 * vscale)<nuw><nsw>))<nuw><nsw>)<nuw>) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: Determining loop execution counts for: @vscalex4_ult_noflags
-; CHECK-NEXT: Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
+; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 536870910
+; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
+; CHECK-NEXT: Loop %for.body: Trip multiple is 1
;
entry:
%vscale = call i32 @llvm.vscale.i32()
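The overflow argument in the `MayAddOverflow` comment of the diff above can be checked by brute force in a narrow type. The following is a minimal sketch under stated assumptions: an 8-bit unsigned model with Start <= End, and a hypothetical helper `roundingAddCanOverflow` that is not part of LLVM. The comment's "UMAX - Stride - 1" is read here as UMAX - (Stride - 1), which is what the preceding step (UMAX mod Stride == Stride - 1) yields.

```cpp
#include <cassert>
#include <cstdint>

// Toy 8-bit check of the claim: if some N satisfies
//   End <= Start + Stride * N <= UMAX
// then (End - Start) + (Stride - 1) does not exceed UMAX.
// Returns true if a violating (Start, End) pair exists for this Stride.
// Stride must be at least 1.
bool roundingAddCanOverflow(unsigned Stride) {
  const unsigned UMAX = 255;
  for (unsigned Start = 0; Start <= UMAX; ++Start) {
    for (unsigned End = Start; End <= UMAX; ++End) {
      // Smallest N with Start + Stride * N >= End.
      unsigned N = (End - Start + Stride - 1) / Stride;
      if (Start + Stride * N > UMAX)
        continue; // precondition fails: the IV itself would overflow first
      if ((End - Start) + (Stride - 1) > UMAX)
        return true; // the rounding addition overflows in 8 bits
    }
  }
  return false;
}
```

In this model a stride of 4 or 128 never produces an overflow, while a stride of 3 does (Start 0, End 255 gives 255 + 2), which matches the role the power-of-two condition plays in the proof.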
You can test this locally with the following command:
git-clang-format --diff 35a2e6d24bcb94720ec7b3aa00e58a1b7b837fbc 08ab563ba02f355a868a08ddcda9ba1c1e2edd86 --extensions h,cpp -- llvm/include/llvm/Analysis/ScalarEvolution.h llvm/lib/Analysis/ScalarEvolution.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a5e2670aeb..264ac392b1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9161,7 +9161,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true)) {
auto Flags = AR->getNoWrapFlags();
Flags = setFlags(Flags, SCEV::FlagNW);
- SmallVector<const SCEV*> Operands{AR->operands()};
+ SmallVector<const SCEV *> Operands{AR->operands()};
Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
}
@@ -10857,11 +10857,9 @@ bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero) {
auto *Mul = dyn_cast<SCEVMulExpr>(S);
if (!Mul)
return false;
- return all_of(Mul->operands(), NonRecursive) &&
- (OrZero || isKnownNonZero(S));
+ return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
}
-
std::pair<const SCEV *, const SCEV *>
ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
// Compute SCEV on entry of loop L.
; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 536870910
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
Is it expected that the vscale_countdown_ne case below isn't handled by this patch?
Yes. This change only covers the existing power-of-two cases. To infer no-self-wrap for these, we also need to handle negative powers of two, which will be a separate patch.
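To make the distinction in this reply concrete: a count-down IV steps by a negative amount, and the two's complement bit pattern of such a step is not a power of two even though its negation is. A minimal standalone sketch in 32-bit arithmetic follows; the helper names are hypothetical and only model the idea, they are not the SCEV code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: true if V has exactly one bit set.
constexpr bool isPowerOf2(uint32_t V) { return V != 0 && (V & (V - 1)) == 0; }

// Hypothetical helper: true if the two's complement negation of V is a
// power of two, e.g. the bit pattern of -4.
constexpr bool isNegatedPowerOf2(uint32_t V) { return isPowerOf2(0u - V); }
```

With these, a step of -4 fails `isPowerOf2` but passes `isNegatedPowerOf2`, so a check that only recognizes positive powers of two would not fire for a count-down loop.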
Co-authored-by: Nikita Popov <[email protected]>
LGTM, but note the clang-format failure.