Skip to content

Commit 79af689

Browse files
authored
[SCEV] Handle more adds in computeConstantDifference() (#101339)
Currently it only deals with the case where we're subtracting adds with at most one non-constant operand. This patch extends it to cancel out common operands for the subtraction of arbitrary add expressions. The background here is that I want to replace a getMinusSCEV() call in LAA with computeConstantDifference(): https://github.com/llvm/llvm-project/blob/93fecc2577ece0329f3bbe2719bbc5b4b9b30010/llvm/lib/Analysis/LoopAccessAnalysis.cpp#L1602-L1603 This particular call is very expensive in some cases (e.g. lencod with LTO) and computeConstantDifference() could achieve this much more cheaply, because it does not need to construct new SCEV expressions. However, the current computeConstantDifference() implementation is too weak for this and misses many basic cases. This is a step towards making it more powerful while still keeping it pretty fast.
1 parent aa0a33b commit 79af689

File tree

2 files changed

+32
-27
lines changed

2 files changed

+32
-27
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11934,8 +11934,9 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
1193411934
// fairly deep in the call stack (i.e. is called many times).
1193511935

1193611936
// X - X = 0.
11937+
unsigned BW = getTypeSizeInBits(More->getType());
1193711938
if (More == Less)
11938-
return APInt(getTypeSizeInBits(More->getType()), 0);
11939+
return APInt(BW, 0);
1193911940

1194011941
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
1194111942
const auto *LAR = cast<SCEVAddRecExpr>(Less);
@@ -11958,33 +11959,31 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
1195811959
// fall through
1195911960
}
1196011961

11961-
if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11962-
const auto &M = cast<SCEVConstant>(More)->getAPInt();
11963-
const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11964-
return M - L;
11965-
}
11966-
11967-
SCEV::NoWrapFlags Flags;
11968-
const SCEV *LLess = nullptr, *RLess = nullptr;
11969-
const SCEV *LMore = nullptr, *RMore = nullptr;
11970-
const SCEVConstant *C1 = nullptr, *C2 = nullptr;
11971-
// Compare (X + C1) vs X.
11972-
if (splitBinaryAdd(Less, LLess, RLess, Flags))
11973-
if ((C1 = dyn_cast<SCEVConstant>(LLess)))
11974-
if (RLess == More)
11975-
return -(C1->getAPInt());
11976-
11977-
// Compare X vs (X + C2).
11978-
if (splitBinaryAdd(More, LMore, RMore, Flags))
11979-
if ((C2 = dyn_cast<SCEVConstant>(LMore)))
11980-
if (RMore == Less)
11981-
return C2->getAPInt();
11962+
// Try to cancel out common factors in two add expressions.
11963+
SmallDenseMap<const SCEV *, int, 8> Multiplicity;
11964+
APInt Diff(BW, 0);
11965+
auto Add = [&](const SCEV *S, int Mul) {
11966+
if (auto *C = dyn_cast<SCEVConstant>(S))
11967+
Diff += C->getAPInt() * Mul;
11968+
else
11969+
Multiplicity[S] += Mul;
11970+
};
11971+
auto Decompose = [&](const SCEV *S, int Mul) {
11972+
if (isa<SCEVAddExpr>(S)) {
11973+
for (const SCEV *Op : S->operands())
11974+
Add(Op, Mul);
11975+
} else
11976+
Add(S, Mul);
11977+
};
11978+
Decompose(More, 1);
11979+
Decompose(Less, -1);
1198211980

11983-
// Compare (X + C1) vs (X + C2).
11984-
if (C1 && C2 && RLess == RMore)
11985-
return C2->getAPInt() - C1->getAPInt();
11981+
// Check whether all the non-constants cancel out.
11982+
for (const auto [_, Mul] : Multiplicity)
11983+
if (Mul != 0)
11984+
return std::nullopt;
1198611985

11987-
return std::nullopt;
11986+
return Diff;
1198811987
}
1198911988

1199011989
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,10 +1117,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
11171117
LLVMContext C;
11181118
SMDiagnostic Err;
11191119
std::unique_ptr<Module> M = parseAssemblyString(
1120-
"define void @foo(i32 %sz, i32 %pp) { "
1120+
"define void @foo(i32 %sz, i32 %pp, i32 %x) { "
11211121
"entry: "
11221122
" %v0 = add i32 %pp, 0 "
11231123
" %v3 = add i32 %pp, 3 "
1124+
" %vx = add i32 %pp, %x "
1125+
" %vx3 = add i32 %vx, 3 "
11241126
" br label %loop.body "
11251127
"loop.body: "
11261128
" %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] "
@@ -1141,6 +1143,9 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
11411143
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
11421144
auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp
11431145
auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp)
1146+
auto *ScevVX = SE.getSCEV(getInstructionByName(F, "vx")); // (%pp + %x)
1147+
// (%pp + %x + 3)
1148+
auto *ScevVX3 = SE.getSCEV(getInstructionByName(F, "vx3"));
11441149
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
11451150
auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1}
11461151
auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1}
@@ -1162,6 +1167,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
11621167
EXPECT_EQ(diff(ScevV0, ScevV3), -3);
11631168
EXPECT_EQ(diff(ScevV0, ScevV0), 0);
11641169
EXPECT_EQ(diff(ScevV3, ScevV3), 0);
1170+
EXPECT_EQ(diff(ScevVX3, ScevVX), 3);
11651171
EXPECT_EQ(diff(ScevIV, ScevIV), 0);
11661172
EXPECT_EQ(diff(ScevXA, ScevXB), 0);
11671173
EXPECT_EQ(diff(ScevXA, ScevYY), -3);

0 commit comments

Comments
 (0)