Skip to content

Commit 6da3361

Browse files
authored
[SCEV] Look through multiply in computeConstantDifference() (#103051)
Inside computeConstantDifference(), handle the case where both sides are of the form `C * %x`, in which case we can strip off the common multiplication (as long as we remember to multiply by it for the following difference calculation). There is an obvious alternative implementation here, which would be to directly decompose multiplies inside the "Multiplicity" accumulation. This does work, but I've found this to be both significantly slower (because everything has to work on APInt) and more complex in implementation (e.g. because we now need to match back the new More/Less with an arbitrary factor) without providing more power in practice. As such, I went for the simpler variant here. This is the last step to make computeConstantDifference() sufficiently powerful to replace existing uses of `cast<SCEVConstant>(getMinusSCEV())` with it.
1 parent 241f9e7 commit 6da3361

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11953,9 +11953,10 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
1195311953

1195411954
unsigned BW = getTypeSizeInBits(More->getType());
1195511955
APInt Diff(BW, 0);
11956+
APInt DiffMul(BW, 1);
1195611957
// Try various simplifications to reduce the difference to a constant. Limit
1195711958
// the number of allowed simplifications to keep compile-time low.
11958-
for (unsigned I = 0; I < 4; ++I) {
11959+
for (unsigned I = 0; I < 5; ++I) {
1195911960
if (More == Less)
1196011961
return Diff;
1196111962

@@ -11980,15 +11981,36 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
1198011981
continue;
1198111982
}
1198211983

11984+
// Try to match a common constant multiply.
11985+
auto MatchConstMul =
11986+
[](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
11987+
auto *M = dyn_cast<SCEVMulExpr>(S);
11988+
if (!M || M->getNumOperands() != 2 ||
11989+
!isa<SCEVConstant>(M->getOperand(0)))
11990+
return std::nullopt;
11991+
return {
11992+
{M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
11993+
};
11994+
if (auto MatchedMore = MatchConstMul(More)) {
11995+
if (auto MatchedLess = MatchConstMul(Less)) {
11996+
if (MatchedMore->second == MatchedLess->second) {
11997+
More = MatchedMore->first;
11998+
Less = MatchedLess->first;
11999+
DiffMul *= MatchedMore->second;
12000+
continue;
12001+
}
12002+
}
12003+
}
12004+
1198312005
// Try to cancel out common factors in two add expressions.
1198412006
SmallDenseMap<const SCEV *, int, 8> Multiplicity;
1198512007
auto Add = [&](const SCEV *S, int Mul) {
1198612008
if (auto *C = dyn_cast<SCEVConstant>(S)) {
1198712009
if (Mul == 1) {
11988-
Diff += C->getAPInt();
12010+
Diff += C->getAPInt() * DiffMul;
1198912011
} else {
1199012012
assert(Mul == -1);
11991-
Diff -= C->getAPInt();
12013+
Diff -= C->getAPInt() * DiffMul;
1199212014
}
1199312015
} else
1199412016
Multiplicity[S] += Mul;

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,8 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
11421142
%var = load i32, ptr %ptr
11431143
%iv2pvar = add i32 %iv2, %var
11441144
%iv2pvarp3 = add i32 %iv2pvar, 3
1145+
%iv2pvarm3 = mul i32 %iv2pvar, 3
1146+
%iv2pvarp3m3 = mul i32 %iv2pvarp3, 3
11451147
%cmp2 = icmp sle i32 %iv2.next, %sz
11461148
br i1 %cmp2, label %loop2.body, label %exit
11471149
exit:
@@ -1178,6 +1180,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
11781180
// %var + {{3,+,1},+,1}
11791181
const SCEV *ScevIV2PVarP3 =
11801182
SE.getSCEV(getInstructionByName(F, "iv2pvarp3"));
1183+
// 3 * (%var + {{0,+,1},+,1})
1184+
const SCEV *ScevIV2PVarM3 =
1185+
SE.getSCEV(getInstructionByName(F, "iv2pvarm3"));
1186+
// 3 * (%var + {{3,+,1},+,1})
1187+
const SCEV *ScevIV2PVarP3M3 =
1188+
SE.getSCEV(getInstructionByName(F, "iv2pvarp3m3"));
11811189

11821190
auto diff = [&SE](const SCEV *LHS, const SCEV *RHS) -> std::optional<int> {
11831191
auto ConstantDiffOrNone = computeConstantDifference(SE, LHS, RHS);
@@ -1204,6 +1212,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
12041212
EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
12051213
EXPECT_EQ(diff(ScevIV2P3, ScevIV2), 3);
12061214
EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), -3);
1215+
EXPECT_EQ(diff(ScevIV2PVarP3M3, ScevIV2PVarM3), 9);
12071216
EXPECT_EQ(diff(ScevV0, ScevIV), std::nullopt);
12081217
EXPECT_EQ(diff(ScevIVNext, ScevV3), std::nullopt);
12091218
EXPECT_EQ(diff(ScevYY, ScevV3), std::nullopt);

0 commit comments

Comments
 (0)