-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SCEV] Handle more adds in computeConstantDifference() #101339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Currently it only deals with the case where we're subtracting adds with at most one non-constant operand. This patch extends it to cancel out common operands for the subtraction of arbitrary add expressions. The background here is that I want to replace a getMinusSCEV() call in LAA with computeConstantDifference(): https://github.com/llvm/llvm-project/blob/93fecc2577ece0329f3bbe2719bbc5b4b9b30010/llvm/lib/Analysis/LoopAccessAnalysis.cpp#L1602-L1603 This particular call is very expensive in some cases (e.g. lencod with LTO) and computeConstantDifference() could achieve this much more cheaply, because it does not need to construct new SCEV expressions. However, the current computeConstantDifference() implementation is too weak for this and misses many basic cases. This is a step towards making it more powerful while still keeping it pretty fast. Compile-time impact: http://llvm-compile-time-tracker.com/compare.php?from=5d833ee6acc85bf108a8787ba233e955728868ab&to=ef3b81a63874e8f05c5b627014b516d4c59388f4&stat=instructions:u
@llvm/pr-subscribers-llvm-analysis Author: Nikita Popov (nikic) ChangesCurrently it only deals with the case where we're subtracting adds with at most one non-constant operand. This patch extends it to cancel out common operands for the subtraction of arbitrary add expressions. The background here is that I want to replace a getMinusSCEV() call in LAA with computeConstantDifference(): llvm-project/llvm/lib/Analysis/LoopAccessAnalysis.cpp Lines 1602 to 1603 in 93fecc2
This particular call is very expensive in some cases (e.g. lencod with LTO) and computeConstantDifference() could achieve this much more cheaply, because it does not need to construct new SCEV expressions. However, the current computeConstantDifference() implementation is too weak for this and misses many basic cases. This is a step towards making it more powerful while still keeping it pretty fast. Compile-time impact: Full diff: https://github.com/llvm/llvm-project/pull/101339.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index fb56d5d436653..bdd36e7d3154f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11923,8 +11923,9 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
// fairly deep in the call stack (i.e. is called many times).
// X - X = 0.
+ unsigned BW = getTypeSizeInBits(More->getType());
if (More == Less)
- return APInt(getTypeSizeInBits(More->getType()), 0);
+ return APInt(BW, 0);
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
const auto *LAR = cast<SCEVAddRecExpr>(Less);
@@ -11947,33 +11948,31 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
// fall through
}
- if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
- const auto &M = cast<SCEVConstant>(More)->getAPInt();
- const auto &L = cast<SCEVConstant>(Less)->getAPInt();
- return M - L;
- }
-
- SCEV::NoWrapFlags Flags;
- const SCEV *LLess = nullptr, *RLess = nullptr;
- const SCEV *LMore = nullptr, *RMore = nullptr;
- const SCEVConstant *C1 = nullptr, *C2 = nullptr;
- // Compare (X + C1) vs X.
- if (splitBinaryAdd(Less, LLess, RLess, Flags))
- if ((C1 = dyn_cast<SCEVConstant>(LLess)))
- if (RLess == More)
- return -(C1->getAPInt());
-
- // Compare X vs (X + C2).
- if (splitBinaryAdd(More, LMore, RMore, Flags))
- if ((C2 = dyn_cast<SCEVConstant>(LMore)))
- if (RMore == Less)
- return C2->getAPInt();
+ // Try to cancel out common factors in two add expressions.
+ SmallDenseMap<const SCEV *, int, 8> Multiplicity;
+ APInt Diff(BW, 0);
+ auto Add = [&](const SCEV *S, int Mul) {
+ if (auto *C = dyn_cast<SCEVConstant>(S))
+ Diff += C->getAPInt() * Mul;
+ else
+ Multiplicity[S] += Mul;
+ };
+ auto Decompose = [&](const SCEV *S, int Mul) {
+ if (isa<SCEVAddExpr>(S)) {
+ for (const SCEV *Op : S->operands())
+ Add(Op, Mul);
+ } else
+ Add(S, Mul);
+ };
+ Decompose(More, 1);
+ Decompose(Less, -1);
- // Compare (X + C1) vs (X + C2).
- if (C1 && C2 && RLess == RMore)
- return C2->getAPInt() - C1->getAPInt();
+ // Check whether all the non-constants cancel out.
+ for (const auto [_, Mul] : Multiplicity)
+ if (Mul != 0)
+ return std::nullopt;
- return std::nullopt;
+ return Diff;
}
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 76e6095636305..3802ae4051f42 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1117,10 +1117,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
- "define void @foo(i32 %sz, i32 %pp) { "
+ "define void @foo(i32 %sz, i32 %pp, i32 %x) { "
"entry: "
" %v0 = add i32 %pp, 0 "
" %v3 = add i32 %pp, 3 "
+ " %vx = add i32 %pp, %x "
+ " %vx3 = add i32 %vx, 3 "
" br label %loop.body "
"loop.body: "
" %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] "
@@ -1141,6 +1143,9 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp
auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp)
+ auto *ScevVX = SE.getSCEV(getInstructionByName(F, "vx")); // (%pp + %x)
+ // (%pp + %x + 3)
+ auto *ScevVX3 = SE.getSCEV(getInstructionByName(F, "vx3"));
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1}
auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1}
@@ -1162,6 +1167,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
EXPECT_EQ(diff(ScevV0, ScevV3), -3);
EXPECT_EQ(diff(ScevV0, ScevV0), 0);
EXPECT_EQ(diff(ScevV3, ScevV3), 0);
+ EXPECT_EQ(diff(ScevVX3, ScevVX), 3);
EXPECT_EQ(diff(ScevIV, ScevIV), 0);
EXPECT_EQ(diff(ScevXA, ScevXB), 0);
EXPECT_EQ(diff(ScevXA, ScevYY), -3);
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Nice!
A couple of thoughts for follow on changes:
- I think you can handle mul -1, X by inverting multiplicity. Could possibly handle other mul-by-constants too.
- Once you did that, could we generalize this slightly to fold getMinusSCEV construction before resorting to the full add case? That wouldn't be all zero multiplicities, but would return the operands and their multiplicity.
Actually, it looks like we have something like this in CollectAddOperandsWithScales, common/share?
Yeah, support for multiplies is something I plan to add as a followup.
I think this is generally a good idea, but the details here are probably a bit tricky. One concern is flag preservation, and the other is that flattening SCEVs earlier can ultimately result in different expressions because our canonicalizations are inconsistent. Or at least I think I ran into this issue when trying something like that in the past.
This does conceptually the same thing, but the constraints are bit different. In particular, that one does want to create new SCEV nodes in some cases, which we don't want inside computeConstantDifference(). |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/30/builds/3329 Here is the relevant piece of the build log for the reference:
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/154/builds/2274 Here is the relevant piece of the build log for the reference:
|
computeConstantDifference() can currently look through addrecs with identical steps, and then through adds with identical operands (apart from constants). However, it fails to handle minor variations, such as two nested add recs, or an outer add with an inner addrec (rather than the other way around). This patch supports these cases by adding a loop over the simplifications, limited to a small number of iterations. The motivation is the same as in llvm#101339, to make computeConstantDifference() powerful enough to replace existing uses of `dyn_cast<SCEVConstant>(getMinusSCEV())` with it. Though as the IR test diff shows, other callers may also benefit.
computeConstantDifference() can currently look through addrecs with identical steps, and then through adds with identical operands (apart from constants). However, it fails to handle minor variations, such as two nested add recs, or an outer add with an inner addrec (rather than the other way around). This patch supports these cases by adding a loop over the simplifications, limited to a small number of iterations. The motivation is the same as in llvm#101339, to make computeConstantDifference() powerful enough to replace existing uses of `dyn_cast<SCEVConstant>(getMinusSCEV())` with it. Though as the IR test diff shows, other callers may also benefit.
…101999) computeConstantDifference() can currently look through addrecs with identical steps, and then through adds with identical operands (apart from constants). However, it fails to handle minor variations, such as two nested add recs, or an outer add with an inner addrec (rather than the other way around). This patch supports these cases by adding a loop over the simplifications, limited to a small number of iterations. The motivation is the same as in #101339, to make computeConstantDifference() powerful enough to replace existing uses of `dyn_cast<SCEVConstant>(getMinusSCEV())` with it. Though as the IR test diff shows, other callers may also benefit.
Currently it only deals with the case where we're subtracting adds with at most one non-constant operand. This patch extends it to cancel out common operands for the subtraction of arbitrary add expressions.
The background here is that I want to replace a getMinusSCEV() call in LAA with computeConstantDifference():
llvm-project/llvm/lib/Analysis/LoopAccessAnalysis.cpp
Lines 1602 to 1603 in 93fecc2
This particular call is very expensive in some cases (e.g. lencod with LTO) and computeConstantDifference() could achieve this much more cheaply, because it does not need to construct new SCEV expressions.
However, the current computeConstantDifference() implementation is too weak for this and misses many basic cases. This is a step towards making it more powerful while still keeping it pretty fast.
Compile-time impact:
http://llvm-compile-time-tracker.com/compare.php?from=5d833ee6acc85bf108a8787ba233e955728868ab&to=ef3b81a63874e8f05c5b627014b516d4c59388f4&stat=instructions:u