Skip to content

[SCEV] Handle more add/addrec mixes in computeConstantDifference() #101999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 13, 2024

Conversation

nikic
Copy link
Contributor

@nikic nikic commented Aug 5, 2024

computeConstantDifference() can currently look through addrecs with identical steps, and then through adds with identical operands (apart from constants).

However, it fails to handle minor variations, such as two nested add recs, or an outer add with an inner addrec (rather than the other way around).

This patch supports these cases by adding a loop over the simplifications, limited to a small number of iterations. The motivation is the same as in #101339, to make
computeConstantDifference() powerful enough to replace existing uses of dyn_cast<SCEVConstant>(getMinusSCEV()) with it. Though as the IR test diff shows, other callers may also benefit.

There is no significant compile-time impact: http://llvm-compile-time-tracker.com/compare.php?from=1fa7f05b709748e8a36936cbb5508070c8214354&to=12530f93c3864dc21c6c273d3a0f08ee59c6a406&stat=instructions%3Au

@nikic nikic requested review from fhahn and preames August 5, 2024 14:57
@llvmbot llvmbot added llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Aug 5, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 5, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Nikita Popov (nikic)

Changes

computeConstantDifference() can currently look through addrecs with identical steps, and then through adds with identical operands (apart from constants).

However, it fails to handle minor variations, such as two nested add recs, or an outer add with an inner addrec (rather than the other way around).

This patch supports these cases by adding a loop over the simplifications, limited to a small number of iterations. The motivation is the same as in #101339, to make
computeConstantDifference() powerful enough to replace existing uses of dyn_cast&lt;SCEVConstant&gt;(getMinusSCEV()) with it. Though as the IR test diff shows, other callers may also benefit.


Full diff: https://github.com/llvm/llvm-project/pull/101999.diff

3 Files Affected:

  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+71-44)
  • (modified) llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll (+23-37)
  • (modified) llvm/unittests/Analysis/ScalarEvolutionTest.cpp (+2-2)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a5ebd5c554c3d..f8ef98a92efa1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11933,56 +11933,83 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
   // We avoid subtracting expressions here because this function is usually
   // fairly deep in the call stack (i.e. is called many times).
 
-  // X - X = 0.
   unsigned BW = getTypeSizeInBits(More->getType());
-  if (More == Less)
-    return APInt(BW, 0);
-
-  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
-    const auto *LAR = cast<SCEVAddRecExpr>(Less);
-    const auto *MAR = cast<SCEVAddRecExpr>(More);
-
-    if (LAR->getLoop() != MAR->getLoop())
-      return std::nullopt;
-
-    // We look at affine expressions only; not for correctness but to keep
-    // getStepRecurrence cheap.
-    if (!LAR->isAffine() || !MAR->isAffine())
-      return std::nullopt;
+  APInt Diff(BW, 0);
+  // Try various simplifications to reduce the difference to a constant. Limit
+  // the number of allowed simplifications to keep compile-time low.
+  for (unsigned I = 0; I < 4; ++I) {
+    if (More == Less)
+      return Diff;
+
+    // Reduce addrecs with identical steps to their start value.
+    if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
+      const auto *LAR = cast<SCEVAddRecExpr>(Less);
+      const auto *MAR = cast<SCEVAddRecExpr>(More);
+
+      if (LAR->getLoop() != MAR->getLoop())
+        return std::nullopt;
+
+      // We look at affine expressions only; not for correctness but to keep
+      // getStepRecurrence cheap.
+      if (!LAR->isAffine() || !MAR->isAffine())
+        return std::nullopt;
+
+      if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
+        return std::nullopt;
+
+      Less = LAR->getStart();
+      More = MAR->getStart();
+      continue;
+    }
 
-    if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
+    // Try to cancel out common factors in two add expressions.
+    SmallDenseMap<const SCEV *, int, 8> Multiplicity;
+    auto Add = [&](const SCEV *S, int Mul) {
+      if (auto *C = dyn_cast<SCEVConstant>(S))
+        Diff += C->getAPInt() * Mul;
+      else
+        Multiplicity[S] += Mul;
+    };
+    auto Decompose = [&](const SCEV *S, int Mul) {
+      if (isa<SCEVAddExpr>(S)) {
+        for (const SCEV *Op : S->operands())
+          Add(Op, Mul);
+      } else
+        Add(S, Mul);
+    };
+    Decompose(More, 1);
+    Decompose(Less, -1);
+
+    // Check whether all the non-constants cancel out, or reduce to new
+    // More/Less values.
+    const SCEV *NewMore = nullptr, *NewLess = nullptr;
+    for (const auto [S, Mul] : Multiplicity) {
+      if (Mul == 0)
+        continue;
+      if (Mul == 1) {
+        if (NewMore)
+          return std::nullopt;
+        NewMore = S;
+      } else if (Mul == -1) {
+        if (NewLess)
+          return std::nullopt;
+        NewLess = S;
+      } else
+        return std::nullopt;
+    }
+
+    // Values stayed the same, no point in trying further.
+    if (NewMore == More || NewLess == Less)
       return std::nullopt;
 
-    Less = LAR->getStart();
-    More = MAR->getStart();
-
-    // fall through
+    More = NewMore;
+    Less = NewLess;
+    if (!More || !Less)
+      break;
   }
 
-  // Try to cancel out common factors in two add expressions.
-  SmallDenseMap<const SCEV *, int, 8> Multiplicity;
-  APInt Diff(BW, 0);
-  auto Add = [&](const SCEV *S, int Mul) {
-    if (auto *C = dyn_cast<SCEVConstant>(S))
-      Diff += C->getAPInt() * Mul;
-    else
-      Multiplicity[S] += Mul;
-  };
-  auto Decompose = [&](const SCEV *S, int Mul) {
-    if (isa<SCEVAddExpr>(S)) {
-      for (const SCEV *Op : S->operands())
-        Add(Op, Mul);
-    } else
-      Add(S, Mul);
-  };
-  Decompose(More, 1);
-  Decompose(Less, -1);
-
-  // Check whether all the non-constants cancel out.
-  for (const auto &[_, Mul] : Multiplicity)
-    if (Mul != 0)
-      return std::nullopt;
-
+  if (More || Less)
+    return std::nullopt; // Did not reduce to constant.
   return Diff;
 }
 
diff --git a/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll b/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
index db5a7105fd8c4..f55e37c777260 100644
--- a/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
@@ -29,13 +29,13 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 2
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_MEMCHECK:%.*]]
 ; CHECK:       vector.memcheck:
-; CHECK-NEXT:    [[UGLYGEP:%.*]] = getelementptr i8, ptr [[L_2_LCSSA]], i64 2
-; CHECK-NEXT:    [[UGLYGEP5:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 2
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[L_2_LCSSA]], i64 2
+; CHECK-NEXT:    [[SCEVGEP5:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 2
 ; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[N]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[TMP1]], 4
-; CHECK-NEXT:    [[UGLYGEP6:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 [[TMP2]]
-; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ult ptr [[L_2_LCSSA]], [[UGLYGEP6]]
-; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ult ptr [[UGLYGEP5]], [[UGLYGEP]]
+; CHECK-NEXT:    [[SCEVGEP6:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 [[TMP2]]
+; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ult ptr [[L_2_LCSSA]], [[SCEVGEP6]]
+; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ult ptr [[SCEVGEP5]], [[SCEVGEP]]
 ; CHECK-NEXT:    [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
 ; CHECK-NEXT:    br i1 [[FOUND_CONFLICT]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
@@ -48,10 +48,10 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[TMP4:%.*]] = add nuw nsw i64 [[TMP3]], 1
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[L_1]], i64 [[TMP4]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i16, ptr [[TMP5]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i16>, ptr [[TMP6]], align 2, !alias.scope !0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i16>, ptr [[TMP6]], align 2, !alias.scope [[META0:![0-9]+]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[L_2]], i64 0
 ; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x i16> [[WIDE_LOAD]], i32 1
-; CHECK-NEXT:    store i16 [[TMP8]], ptr [[TMP7]], align 2, !alias.scope !3, !noalias !0
+; CHECK-NEXT:    store i16 [[TMP8]], ptr [[TMP7]], align 2, !alias.scope [[META3:![0-9]+]], !noalias [[META0]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
 ; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
@@ -74,7 +74,7 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[LOOP_L_1:%.*]] = load i16, ptr [[GEP_1]], align 2
 ; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i16, ptr [[L_2_LCSSA]], i64 0
 ; CHECK-NEXT:    store i16 [[LOOP_L_1]], ptr [[GEP_2]], align 2
-; CHECK-NEXT:    br i1 [[C_5]], label [[LOOP_3]], label [[EXIT_LOOPEXIT]], !llvm.loop [[LOOP7:![0-9]+]]
+; CHECK-NEXT:    br i1 [[C_5]], label [[LOOP_3]], label [[EXIT_LOOPEXIT]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       exit.loopexit:
 ; CHECK-NEXT:    br label [[EXIT:%.*]]
 ; CHECK:       exit.loopexit1:
@@ -138,31 +138,17 @@ define void @test2(ptr %dst) {
 ; CHECK-NEXT:    [[INDVAR_NEXT]] = add i32 [[INDVAR]], 1
 ; CHECK-NEXT:    br i1 [[C_1]], label [[LOOP_2]], label [[LOOP_3_PH:%.*]]
 ; CHECK:       loop.3.ph:
-; CHECK-NEXT:    [[INDVAR_LCSSA1:%.*]] = phi i32 [ [[INDVAR]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[INDVAR_LCSSA:%.*]] = phi i32 [ [[INDVAR]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[IV_1_LCSSA:%.*]] = phi i64 [ [[IV_1]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = and i64 [[IV_1_LCSSA]], 4294967295
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[INDVAR_LCSSA1]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[INDVAR_LCSSA]], -1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], 1000
-; CHECK-NEXT:    [[SMIN2:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP2]], i32 1)
-; CHECK-NEXT:    [[TMP3:%.*]] = sub i32 [[TMP2]], [[SMIN2]]
+; CHECK-NEXT:    [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP2]], i32 1)
+; CHECK-NEXT:    [[TMP3:%.*]] = sub i32 [[TMP2]], [[SMIN]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
 ; CHECK-NEXT:    [[TMP5:%.*]] = add nuw nsw i64 [[TMP4]], 1
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP5]], 2
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
-; CHECK:       vector.scevcheck:
-; CHECK-NEXT:    [[TMP6:%.*]] = mul i32 [[INDVAR_LCSSA]], -1
-; CHECK-NEXT:    [[TMP7:%.*]] = add i32 [[TMP6]], 1000
-; CHECK-NEXT:    [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP7]], i32 1)
-; CHECK-NEXT:    [[TMP8:%.*]] = sub i32 [[TMP7]], [[SMIN]]
-; CHECK-NEXT:    [[TMP9:%.*]] = add i32 [[TMP6]], 999
-; CHECK-NEXT:    [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 1, i32 [[TMP8]])
-; CHECK-NEXT:    [[MUL_RESULT:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
-; CHECK-NEXT:    [[MUL_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
-; CHECK-NEXT:    [[TMP10:%.*]] = sub i32 [[TMP9]], [[MUL_RESULT]]
-; CHECK-NEXT:    [[TMP11:%.*]] = icmp ugt i32 [[TMP10]], [[TMP9]]
-; CHECK-NEXT:    [[TMP12:%.*]] = or i1 [[TMP11]], [[MUL_OVERFLOW]]
-; CHECK-NEXT:    br i1 [[TMP12]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP5]], 2
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP5]], [[N_MOD_VF]]
@@ -171,21 +157,21 @@ define void @test2(ptr %dst) {
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = sub i64 [[TMP0]], [[INDEX]]
-; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[OFFSET_IDX]], 0
-; CHECK-NEXT:    [[TMP14:%.*]] = add nsw i64 [[TMP13]], -1
-; CHECK-NEXT:    [[TMP15:%.*]] = and i64 [[TMP14]], 4294967295
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP15]]
-; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i32, ptr [[TMP16]], i32 0
-; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[TMP17]], i32 -1
-; CHECK-NEXT:    store <2 x i32> zeroinitializer, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[OFFSET_IDX]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = add nsw i64 [[TMP6]], -1
+; CHECK-NEXT:    [[TMP8:%.*]] = and i64 [[TMP7]], 4294967295
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[TMP9]], i32 0
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds i32, ptr [[TMP10]], i32 -1
+; CHECK-NEXT:    store <2 x i32> zeroinitializer, ptr [[TMP11]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP5]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[LOOP_1_LATCH:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ [[TMP0]], [[LOOP_3_PH]] ], [ [[TMP0]], [[VECTOR_SCEVCHECK]] ]
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ [[TMP0]], [[LOOP_3_PH]] ]
 ; CHECK-NEXT:    br label [[LOOP_3:%.*]]
 ; CHECK:       loop.3:
 ; CHECK-NEXT:    [[IV_2:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV_2_NEXT:%.*]], [[LOOP_3]] ]
@@ -195,7 +181,7 @@ define void @test2(ptr %dst) {
 ; CHECK-NEXT:    store i32 0, ptr [[GEP_DST]], align 4
 ; CHECK-NEXT:    [[IV_2_TRUNC:%.*]] = trunc i64 [[IV_2]] to i32
 ; CHECK-NEXT:    [[EC:%.*]] = icmp sgt i32 [[IV_2_TRUNC]], 1
-; CHECK-NEXT:    br i1 [[EC]], label [[LOOP_3]], label [[LOOP_1_LATCH]], !llvm.loop [[LOOP9:![0-9]+]]
+; CHECK-NEXT:    br i1 [[EC]], label [[LOOP_3]], label [[LOOP_1_LATCH]], !llvm.loop [[LOOP10:![0-9]+]]
 ; CHECK:       loop.1.latch:
 ; CHECK-NEXT:    [[C_2:%.*]] = call i1 @cond()
 ; CHECK-NEXT:    br i1 [[C_2]], label [[EXIT:%.*]], label [[LOOP_1_HEADER]]
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 64a7503d30eed..571051f6ab55b 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1199,8 +1199,8 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
     EXPECT_EQ(diff(ScevIV, ScevIVNext), -1);
     EXPECT_EQ(diff(ScevIVNext, ScevIV), 1);
     EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
-    EXPECT_EQ(diff(ScevIV2P3, ScevIV2), std::nullopt); // TODO
-    EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), std::nullopt); // TODO
+    EXPECT_EQ(diff(ScevIV2P3, ScevIV2), 3);
+    EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), -3);
     EXPECT_EQ(diff(ScevV0, ScevIV), std::nullopt);
     EXPECT_EQ(diff(ScevIVNext, ScevV3), std::nullopt);
     EXPECT_EQ(diff(ScevYY, ScevV3), std::nullopt);

@llvmbot
Copy link
Member

llvmbot commented Aug 5, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

Changes

computeConstantDifference() can currently look through addrecs with identical steps, and then through adds with identical operands (apart from constants).

However, it fails to handle minor variations, such as two nested add recs, or an outer add with an inner addrec (rather than the other way around).

This patch supports these cases by adding a loop over the simplifications, limited to a small number of iterations. The motivation is the same as in #101339, to make
computeConstantDifference() powerful enough to replace existing uses of dyn_cast&lt;SCEVConstant&gt;(getMinusSCEV()) with it. Though as the IR test diff shows, other callers may also benefit.


Full diff: https://github.com/llvm/llvm-project/pull/101999.diff

3 Files Affected:

  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+71-44)
  • (modified) llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll (+23-37)
  • (modified) llvm/unittests/Analysis/ScalarEvolutionTest.cpp (+2-2)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a5ebd5c554c3d..f8ef98a92efa1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11933,56 +11933,83 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
   // We avoid subtracting expressions here because this function is usually
   // fairly deep in the call stack (i.e. is called many times).
 
-  // X - X = 0.
   unsigned BW = getTypeSizeInBits(More->getType());
-  if (More == Less)
-    return APInt(BW, 0);
-
-  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
-    const auto *LAR = cast<SCEVAddRecExpr>(Less);
-    const auto *MAR = cast<SCEVAddRecExpr>(More);
-
-    if (LAR->getLoop() != MAR->getLoop())
-      return std::nullopt;
-
-    // We look at affine expressions only; not for correctness but to keep
-    // getStepRecurrence cheap.
-    if (!LAR->isAffine() || !MAR->isAffine())
-      return std::nullopt;
+  APInt Diff(BW, 0);
+  // Try various simplifications to reduce the difference to a constant. Limit
+  // the number of allowed simplifications to keep compile-time low.
+  for (unsigned I = 0; I < 4; ++I) {
+    if (More == Less)
+      return Diff;
+
+    // Reduce addrecs with identical steps to their start value.
+    if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
+      const auto *LAR = cast<SCEVAddRecExpr>(Less);
+      const auto *MAR = cast<SCEVAddRecExpr>(More);
+
+      if (LAR->getLoop() != MAR->getLoop())
+        return std::nullopt;
+
+      // We look at affine expressions only; not for correctness but to keep
+      // getStepRecurrence cheap.
+      if (!LAR->isAffine() || !MAR->isAffine())
+        return std::nullopt;
+
+      if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
+        return std::nullopt;
+
+      Less = LAR->getStart();
+      More = MAR->getStart();
+      continue;
+    }
 
-    if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
+    // Try to cancel out common factors in two add expressions.
+    SmallDenseMap<const SCEV *, int, 8> Multiplicity;
+    auto Add = [&](const SCEV *S, int Mul) {
+      if (auto *C = dyn_cast<SCEVConstant>(S))
+        Diff += C->getAPInt() * Mul;
+      else
+        Multiplicity[S] += Mul;
+    };
+    auto Decompose = [&](const SCEV *S, int Mul) {
+      if (isa<SCEVAddExpr>(S)) {
+        for (const SCEV *Op : S->operands())
+          Add(Op, Mul);
+      } else
+        Add(S, Mul);
+    };
+    Decompose(More, 1);
+    Decompose(Less, -1);
+
+    // Check whether all the non-constants cancel out, or reduce to new
+    // More/Less values.
+    const SCEV *NewMore = nullptr, *NewLess = nullptr;
+    for (const auto [S, Mul] : Multiplicity) {
+      if (Mul == 0)
+        continue;
+      if (Mul == 1) {
+        if (NewMore)
+          return std::nullopt;
+        NewMore = S;
+      } else if (Mul == -1) {
+        if (NewLess)
+          return std::nullopt;
+        NewLess = S;
+      } else
+        return std::nullopt;
+    }
+
+    // Values stayed the same, no point in trying further.
+    if (NewMore == More || NewLess == Less)
       return std::nullopt;
 
-    Less = LAR->getStart();
-    More = MAR->getStart();
-
-    // fall through
+    More = NewMore;
+    Less = NewLess;
+    if (!More || !Less)
+      break;
   }
 
-  // Try to cancel out common factors in two add expressions.
-  SmallDenseMap<const SCEV *, int, 8> Multiplicity;
-  APInt Diff(BW, 0);
-  auto Add = [&](const SCEV *S, int Mul) {
-    if (auto *C = dyn_cast<SCEVConstant>(S))
-      Diff += C->getAPInt() * Mul;
-    else
-      Multiplicity[S] += Mul;
-  };
-  auto Decompose = [&](const SCEV *S, int Mul) {
-    if (isa<SCEVAddExpr>(S)) {
-      for (const SCEV *Op : S->operands())
-        Add(Op, Mul);
-    } else
-      Add(S, Mul);
-  };
-  Decompose(More, 1);
-  Decompose(Less, -1);
-
-  // Check whether all the non-constants cancel out.
-  for (const auto &[_, Mul] : Multiplicity)
-    if (Mul != 0)
-      return std::nullopt;
-
+  if (More || Less)
+    return std::nullopt; // Did not reduce to constant.
   return Diff;
 }
 
diff --git a/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll b/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
index db5a7105fd8c4..f55e37c777260 100644
--- a/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll
@@ -29,13 +29,13 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 2
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_MEMCHECK:%.*]]
 ; CHECK:       vector.memcheck:
-; CHECK-NEXT:    [[UGLYGEP:%.*]] = getelementptr i8, ptr [[L_2_LCSSA]], i64 2
-; CHECK-NEXT:    [[UGLYGEP5:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 2
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[L_2_LCSSA]], i64 2
+; CHECK-NEXT:    [[SCEVGEP5:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 2
 ; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[N]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[TMP1]], 4
-; CHECK-NEXT:    [[UGLYGEP6:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 [[TMP2]]
-; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ult ptr [[L_2_LCSSA]], [[UGLYGEP6]]
-; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ult ptr [[UGLYGEP5]], [[UGLYGEP]]
+; CHECK-NEXT:    [[SCEVGEP6:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 [[TMP2]]
+; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ult ptr [[L_2_LCSSA]], [[SCEVGEP6]]
+; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ult ptr [[SCEVGEP5]], [[SCEVGEP]]
 ; CHECK-NEXT:    [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
 ; CHECK-NEXT:    br i1 [[FOUND_CONFLICT]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
@@ -48,10 +48,10 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[TMP4:%.*]] = add nuw nsw i64 [[TMP3]], 1
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[L_1]], i64 [[TMP4]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i16, ptr [[TMP5]], i32 0
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i16>, ptr [[TMP6]], align 2, !alias.scope !0
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x i16>, ptr [[TMP6]], align 2, !alias.scope [[META0:![0-9]+]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[L_2]], i64 0
 ; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x i16> [[WIDE_LOAD]], i32 1
-; CHECK-NEXT:    store i16 [[TMP8]], ptr [[TMP7]], align 2, !alias.scope !3, !noalias !0
+; CHECK-NEXT:    store i16 [[TMP8]], ptr [[TMP7]], align 2, !alias.scope [[META3:![0-9]+]], !noalias [[META0]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
 ; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
@@ -74,7 +74,7 @@ define i16 @test(ptr %arg, i64 %N) {
 ; CHECK-NEXT:    [[LOOP_L_1:%.*]] = load i16, ptr [[GEP_1]], align 2
 ; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i16, ptr [[L_2_LCSSA]], i64 0
 ; CHECK-NEXT:    store i16 [[LOOP_L_1]], ptr [[GEP_2]], align 2
-; CHECK-NEXT:    br i1 [[C_5]], label [[LOOP_3]], label [[EXIT_LOOPEXIT]], !llvm.loop [[LOOP7:![0-9]+]]
+; CHECK-NEXT:    br i1 [[C_5]], label [[LOOP_3]], label [[EXIT_LOOPEXIT]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       exit.loopexit:
 ; CHECK-NEXT:    br label [[EXIT:%.*]]
 ; CHECK:       exit.loopexit1:
@@ -138,31 +138,17 @@ define void @test2(ptr %dst) {
 ; CHECK-NEXT:    [[INDVAR_NEXT]] = add i32 [[INDVAR]], 1
 ; CHECK-NEXT:    br i1 [[C_1]], label [[LOOP_2]], label [[LOOP_3_PH:%.*]]
 ; CHECK:       loop.3.ph:
-; CHECK-NEXT:    [[INDVAR_LCSSA1:%.*]] = phi i32 [ [[INDVAR]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[INDVAR_LCSSA:%.*]] = phi i32 [ [[INDVAR]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[IV_1_LCSSA:%.*]] = phi i64 [ [[IV_1]], [[LOOP_2]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = and i64 [[IV_1_LCSSA]], 4294967295
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[INDVAR_LCSSA1]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[INDVAR_LCSSA]], -1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], 1000
-; CHECK-NEXT:    [[SMIN2:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP2]], i32 1)
-; CHECK-NEXT:    [[TMP3:%.*]] = sub i32 [[TMP2]], [[SMIN2]]
+; CHECK-NEXT:    [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP2]], i32 1)
+; CHECK-NEXT:    [[TMP3:%.*]] = sub i32 [[TMP2]], [[SMIN]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
 ; CHECK-NEXT:    [[TMP5:%.*]] = add nuw nsw i64 [[TMP4]], 1
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP5]], 2
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
-; CHECK:       vector.scevcheck:
-; CHECK-NEXT:    [[TMP6:%.*]] = mul i32 [[INDVAR_LCSSA]], -1
-; CHECK-NEXT:    [[TMP7:%.*]] = add i32 [[TMP6]], 1000
-; CHECK-NEXT:    [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP7]], i32 1)
-; CHECK-NEXT:    [[TMP8:%.*]] = sub i32 [[TMP7]], [[SMIN]]
-; CHECK-NEXT:    [[TMP9:%.*]] = add i32 [[TMP6]], 999
-; CHECK-NEXT:    [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 1, i32 [[TMP8]])
-; CHECK-NEXT:    [[MUL_RESULT:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
-; CHECK-NEXT:    [[MUL_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
-; CHECK-NEXT:    [[TMP10:%.*]] = sub i32 [[TMP9]], [[MUL_RESULT]]
-; CHECK-NEXT:    [[TMP11:%.*]] = icmp ugt i32 [[TMP10]], [[TMP9]]
-; CHECK-NEXT:    [[TMP12:%.*]] = or i1 [[TMP11]], [[MUL_OVERFLOW]]
-; CHECK-NEXT:    br i1 [[TMP12]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP5]], 2
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP5]], [[N_MOD_VF]]
@@ -171,21 +157,21 @@ define void @test2(ptr %dst) {
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = sub i64 [[TMP0]], [[INDEX]]
-; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[OFFSET_IDX]], 0
-; CHECK-NEXT:    [[TMP14:%.*]] = add nsw i64 [[TMP13]], -1
-; CHECK-NEXT:    [[TMP15:%.*]] = and i64 [[TMP14]], 4294967295
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP15]]
-; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i32, ptr [[TMP16]], i32 0
-; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[TMP17]], i32 -1
-; CHECK-NEXT:    store <2 x i32> zeroinitializer, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[OFFSET_IDX]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = add nsw i64 [[TMP6]], -1
+; CHECK-NEXT:    [[TMP8:%.*]] = and i64 [[TMP7]], 4294967295
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[TMP9]], i32 0
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds i32, ptr [[TMP10]], i32 -1
+; CHECK-NEXT:    store <2 x i32> zeroinitializer, ptr [[TMP11]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP5]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[LOOP_1_LATCH:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ [[TMP0]], [[LOOP_3_PH]] ], [ [[TMP0]], [[VECTOR_SCEVCHECK]] ]
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ [[TMP0]], [[LOOP_3_PH]] ]
 ; CHECK-NEXT:    br label [[LOOP_3:%.*]]
 ; CHECK:       loop.3:
 ; CHECK-NEXT:    [[IV_2:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV_2_NEXT:%.*]], [[LOOP_3]] ]
@@ -195,7 +181,7 @@ define void @test2(ptr %dst) {
 ; CHECK-NEXT:    store i32 0, ptr [[GEP_DST]], align 4
 ; CHECK-NEXT:    [[IV_2_TRUNC:%.*]] = trunc i64 [[IV_2]] to i32
 ; CHECK-NEXT:    [[EC:%.*]] = icmp sgt i32 [[IV_2_TRUNC]], 1
-; CHECK-NEXT:    br i1 [[EC]], label [[LOOP_3]], label [[LOOP_1_LATCH]], !llvm.loop [[LOOP9:![0-9]+]]
+; CHECK-NEXT:    br i1 [[EC]], label [[LOOP_3]], label [[LOOP_1_LATCH]], !llvm.loop [[LOOP10:![0-9]+]]
 ; CHECK:       loop.1.latch:
 ; CHECK-NEXT:    [[C_2:%.*]] = call i1 @cond()
 ; CHECK-NEXT:    br i1 [[C_2]], label [[EXIT:%.*]], label [[LOOP_1_HEADER]]
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 64a7503d30eed..571051f6ab55b 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1199,8 +1199,8 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
     EXPECT_EQ(diff(ScevIV, ScevIVNext), -1);
     EXPECT_EQ(diff(ScevIVNext, ScevIV), 1);
     EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
-    EXPECT_EQ(diff(ScevIV2P3, ScevIV2), std::nullopt); // TODO
-    EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), std::nullopt); // TODO
+    EXPECT_EQ(diff(ScevIV2P3, ScevIV2), 3);
+    EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), -3);
     EXPECT_EQ(diff(ScevV0, ScevIV), std::nullopt);
     EXPECT_EQ(diff(ScevIVNext, ScevV3), std::nullopt);
     EXPECT_EQ(diff(ScevYY, ScevV3), std::nullopt);

@nikic nikic changed the title [SCEV] Handle more add/addrec mixed in computeConstantDifference() [SCEV] Handle more add/addrec mixes in computeConstantDifference() Aug 5, 2024
computeConstantDifference() can currently look through addrecs with
identical steps, and then through adds with identical operands
(apart from constants).

However, it fails to handle minor variations, such as two nested
add recs, or an outer add with an inner addrec (rather than the
other way around).

This patch supports these cases by adding a loop over the
simplifications, limited to a small number of iterations. The
motivation is the same as in llvm#101339, to make
computeConstantDifference() powerful enough to replace existing
uses of `dyn_cast<SCEVConstant>(getMinusSCEV())` with it. Though
as the IR test diff shows, other callers may also benefit.
@nikic nikic force-pushed the scev-const-diff-2 branch from 885fe7f to d9c41bf Compare August 12, 2024 13:40
Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/optional comment

// fall through
More = NewMore;
Less = NewLess;
if (!More || !Less)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The stylistic choice here to break instead of just returning for the two sub-cases here seems a bit odd. It would seem to be a simpler invariant to always have More and Less valid values on the backedge.

@nikic nikic merged commit 306b9c7 into llvm:main Aug 13, 2024
6 of 8 checks passed
@nikic nikic deleted the scev-const-diff-2 branch August 13, 2024 09:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants