Skip to content

[SCEV] Use power of two facts involving vscale when inferring wrap flags #101380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 31, 2024

Conversation

preames
Copy link
Collaborator

@preames preames commented Jul 31, 2024

SCEV has logic for inferring wrap flags on AddRecs which are known to control an exit based on whether the step is a power of two. This logic only considered constants, and thus did not trigger for steps such as (4 x vscale) which are common in scalably vectorized loops.

The net effect is that we were very sensative to the preservation of nsw/nuw flags on such IVs, and could not infer trip counts if they got lost for any reason.

SCEV has logic for inferring wrap flags on AddRecs which are known to
control an exit based on whether the step is a power of two.  This logic
only considered constants, and thus did not trigger for steps such
as (4 x vscale) which are common in scalably vectorized loops.

The net effect is that we were very sensative to the preservation of
nsw/nuw flags on such IVs, and could not infer trip counts if they
got lost for any reason.
@preames preames requested a review from nikic as a code owner July 31, 2024 18:07
@llvmbot llvmbot added the llvm:analysis Includes value tracking, cost tables and constant folding label Jul 31, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 31, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Philip Reames (preames)

Changes

SCEV has logic for inferring wrap flags on AddRecs which are known to control an exit based on whether the step is a power of two. This logic only considered constants, and thus did not trigger for steps such as (4 x vscale) which are common in scalably vectorized loops.

The net effect is that we were very sensative to the preservation of nsw/nuw flags on such IVs, and could not infer trip counts if they got lost for any reason.


Full diff: https://github.com/llvm/llvm-project/pull/101380.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (+3)
  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+72-58)
  • (modified) llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll (+14-12)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index d9bfca763819f..fbefa2bd074dd 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1028,6 +1028,9 @@ class ScalarEvolution {
   /// Test if the given expression is known to be non-zero.
   bool isKnownNonZero(const SCEV *S);
 
+  /// Test if the given expression is known to be a power of 2.
+  bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero = false);
+
   /// Splits SCEV expression \p S into two SCEVs. One of them is obtained from
   /// \p S by substitution of all AddRec sub-expression related to loop \p L
   /// with initial value of that SCEV. The second is obtained from \p S by
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index fb56d5d436653..159aa6e93a6ad 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9156,16 +9156,14 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     auto *InnerLHS = LHS;
     if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
       InnerLHS = ZExt->getOperand();
-    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
-      auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
-      if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
-          StrideC && StrideC->getAPInt().isPowerOf2()) {
-        auto Flags = AR->getNoWrapFlags();
-        Flags = setFlags(Flags, SCEV::FlagNW);
-        SmallVector<const SCEV*> Operands{AR->operands()};
-        Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
-        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
-      }
+    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
+        AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
+        isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this))) {
+      auto Flags = AR->getNoWrapFlags();
+      Flags = setFlags(Flags, SCEV::FlagNW);
+      SmallVector<const SCEV*> Operands{AR->operands()};
+      Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
+      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
     }
   }
 
@@ -10845,6 +10843,25 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
   return getUnsignedRangeMin(S) != 0;
 }
 
+bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero) {
+  auto nonRecursive = [this](const SCEV *S) {
+    if (auto *C = dyn_cast<SCEVConstant>(S))
+      return C->getAPInt().isPowerOf2();
+    // The vscale_range indicates vscale is a power-of-two.
+    return S->getSCEVType() == scVScale && F.hasFnAttribute(Attribute::VScaleRange);;
+  };
+
+  if (nonRecursive(S))
+    return true;
+
+  auto *Mul = dyn_cast<SCEVMulExpr>(S);
+  if (!Mul || Mul->getNumOperands() != 2)
+    return false;
+  return nonRecursive(Mul->getOperand(0)) && nonRecursive(Mul->getOperand(1)) &&
+    (OrZero || isKnownNonZero(S));
+}
+
+
 std::pair<const SCEV *, const SCEV *>
 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
   // Compute SCEV on entry of loop L.
@@ -12775,8 +12792,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     if (!isLoopInvariant(RHS, L))
       return false;
 
-    auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
-    if (!StrideC || !StrideC->getAPInt().isPowerOf2())
+    if (!isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this)))
       return false;
 
     if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
@@ -13132,52 +13148,50 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       // "(Start - End) + (Stride - 1)" has unsigned overflow.
       const SCEV *One = getOne(Stride->getType());
       bool MayAddOverflow = [&] {
-        if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
-          if (StrideC->getAPInt().isPowerOf2()) {
-            // Suppose Stride is a power of two, and Start/End are unsigned
-            // integers.  Let UMAX be the largest representable unsigned
-            // integer.
-            //
-            // By the preconditions of this function, we know
-            // "(Start + Stride * N) >= End", and this doesn't overflow.
-            // As a formula:
-            //
-            //   End <= (Start + Stride * N) <= UMAX
-            //
-            // Subtracting Start from all the terms:
-            //
-            //   End - Start <= Stride * N <= UMAX - Start
-            //
-            // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
-            //
-            //   End - Start <= Stride * N <= UMAX
-            //
-            // Stride * N is a multiple of Stride. Therefore,
-            //
-            //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
-            //
-            // Since Stride is a power of two, UMAX + 1 is divisible by
-            // Stride. Therefore, UMAX mod Stride == Stride - 1.  So we can
-            // write:
-            //
-            //   End - Start <= Stride * N <= UMAX - Stride - 1
-            //
-            // Dropping the middle term:
-            //
-            //   End - Start <= UMAX - Stride - 1
-            //
-            // Adding Stride - 1 to both sides:
-            //
-            //   (End - Start) + (Stride - 1) <= UMAX
-            //
-            // In other words, the addition doesn't have unsigned overflow.
-            //
-            // A similar proof works if we treat Start/End as signed values.
-            // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
-            // to use signed max instead of unsigned max. Note that we're
-            // trying to prove a lack of unsigned overflow in either case.
-            return false;
-          }
+        if (isKnownToBeAPowerOfTwo(Stride)) {
+          // Suppose Stride is a power of two, and Start/End are unsigned
+          // integers.  Let UMAX be the largest representable unsigned
+          // integer.
+          //
+          // By the preconditions of this function, we know
+          // "(Start + Stride * N) >= End", and this doesn't overflow.
+          // As a formula:
+          //
+          //   End <= (Start + Stride * N) <= UMAX
+          //
+          // Subtracting Start from all the terms:
+          //
+          //   End - Start <= Stride * N <= UMAX - Start
+          //
+          // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
+          //
+          //   End - Start <= Stride * N <= UMAX
+          //
+          // Stride * N is a multiple of Stride. Therefore,
+          //
+          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
+          //
+          // Since Stride is a power of two, UMAX + 1 is divisible by
+          // Stride. Therefore, UMAX mod Stride == Stride - 1.  So we can
+          // write:
+          //
+          //   End - Start <= Stride * N <= UMAX - Stride - 1
+          //
+          // Dropping the middle term:
+          //
+          //   End - Start <= UMAX - Stride - 1
+          //
+          // Adding Stride - 1 to both sides:
+          //
+          //   (End - Start) + (Stride - 1) <= UMAX
+          //
+          // In other words, the addition doesn't have unsigned overflow.
+          //
+          // A similar proof works if we treat Start/End as signed values.
+          // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
+          // to use signed max instead of unsigned max. Note that we're
+          // trying to prove a lack of unsigned overflow in either case.
+          return false;
         }
         if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
           // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
index 943389d07eb8b..50e6014734f31 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
@@ -374,15 +374,16 @@ define void @vscale_slt_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_ra
 ; CHECK-NEXT:    %vscale = call i32 @llvm.vscale.i32()
 ; CHECK-NEXT:    --> vscale U: [2,1025) S: [2,1025)
 ; CHECK-NEXT:    %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
-; CHECK-NEXT:    --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: (vscale * ((-1 + %n) /u vscale))<nuw> LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:    %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
-; CHECK-NEXT:    --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((4 * vscale * ((-1 + %n) /u vscale)) + %A) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:    %add = add i32 %i.05, %vscale
-; CHECK-NEXT:    --> {vscale,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {vscale,+,vscale}<nw><%for.body> U: full-set S: full-set Exits: (vscale * (1 + ((-1 + %n) /u vscale))<nuw>) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:  Determining loop execution counts for: @vscale_slt_noflags
-; CHECK-NEXT:  Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is ((-1 + %n) /u vscale)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 1073741822
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u vscale)
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
 ;
 entry:
   %vscale = call i32 @llvm.vscale.i32()
@@ -411,15 +412,16 @@ define void @vscalex4_ult_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_
 ; CHECK-NEXT:    %VF = mul i32 %vscale, 4
 ; CHECK-NEXT:    --> (4 * vscale)<nuw><nsw> U: [8,4097) S: [8,4097)
 ; CHECK-NEXT:    %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
-; CHECK-NEXT:    --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (4 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:    %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
-; CHECK-NEXT:    --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((16 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) + %A) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:    %add = add i32 %i.05, %VF
-; CHECK-NEXT:    --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<nw><%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (vscale * (4 + (4 * ((-1 + %n) /u (4 * vscale)<nuw><nsw>))<nuw><nsw>)<nuw>) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:  Determining loop execution counts for: @vscalex4_ult_noflags
-; CHECK-NEXT:  Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 536870910
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
 ;
 entry:
   %vscale = call i32 @llvm.vscale.i32()

Copy link

github-actions bot commented Jul 31, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 35a2e6d24bcb94720ec7b3aa00e58a1b7b837fbc 08ab563ba02f355a868a08ddcda9ba1c1e2edd86 --extensions h,cpp -- llvm/include/llvm/Analysis/ScalarEvolution.h llvm/lib/Analysis/ScalarEvolution.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a5e2670aeb..264ac392b1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9161,7 +9161,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
         isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true)) {
       auto Flags = AR->getNoWrapFlags();
       Flags = setFlags(Flags, SCEV::FlagNW);
-      SmallVector<const SCEV*> Operands{AR->operands()};
+      SmallVector<const SCEV *> Operands{AR->operands()};
       Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
       setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
     }
@@ -10857,11 +10857,9 @@ bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero) {
   auto *Mul = dyn_cast<SCEVMulExpr>(S);
   if (!Mul)
     return false;
-  return all_of(Mul->operands(), NonRecursive) &&
-    (OrZero || isKnownNonZero(S));
+  return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
 }
 
-
 std::pair<const SCEV *, const SCEV *>
 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
   // Compute SCEV on entry of loop L.

; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 536870910
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it expected that the vscale_countdown_ne case below isn't handled by this patch?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This change only includes the existing power-of-two cases. We need to handle negative power of two as well to infer no-self-wrap for these, which will be a separate patch.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but note the clang-format failure.

@preames preames merged commit 7583c48 into llvm:main Jul 31, 2024
4 of 6 checks passed
@preames preames deleted the pr-scev-vscale-is-power-of-two branch July 31, 2024 21:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants