-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SCEV] Look through multiply in computeConstantDifference() #103051
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Inside computeConstantDifference(), handle the case where both sides are of the form `C * %x`, in which case we can strip off the common multiplication (as long as we remember to multiply by it for the following difference calculation). There is an obvious alternative implementation here, which would be to directly decompose multiplies inside the "Multiplicity" accumulation. This does work, but I've found this to be both significantly slower (because everything has to work on APInt) and more complex in implementation (e.g. because we now need to match back the new More/Less with an arbitrary factor) without providing more power in practice. As such, I went for the simpler variant here. This is the last step to make computeConstantDifference() sufficiently powerful to replace existing uses of `cast<SCEVConstant>(getMinusSCEV())` with it.
@llvm/pr-subscribers-llvm-analysis Author: Nikita Popov (nikic) ChangesInside computeConstantDifference(), handle the case where both sides are of the form There is an obvious alternative implementation here, which would be to directly decompose multiplies inside the "Multiplicity" accumulation. This does work, but I've found this to be both significantly slower (because everything has to work on APInt) and more complex in implementation (e.g. because we now need to match back the new More/Less with an arbitrary factor) without providing more power in practice. As such, I went for the simpler variant here. This is the last step to make computeConstantDifference() sufficiently powerful to replace existing uses of Full diff: https://github.com/llvm/llvm-project/pull/103051.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 78975bee4d72c4..d03bb0d2156433 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11953,9 +11953,10 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
unsigned BW = getTypeSizeInBits(More->getType());
APInt Diff(BW, 0);
+ APInt DiffMul(BW, 1);
// Try various simplifications to reduce the difference to a constant. Limit
// the number of allowed simplifications to keep compile-time low.
- for (unsigned I = 0; I < 4; ++I) {
+ for (unsigned I = 0; I < 5; ++I) {
if (More == Less)
return Diff;
@@ -11980,15 +11981,36 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
continue;
}
+ // Try to match a common constant multiply.
+ auto MatchConstMul =
+ [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
+ auto *M = dyn_cast<SCEVMulExpr>(S);
+ if (!M || M->getNumOperands() != 2 ||
+ !isa<SCEVConstant>(M->getOperand(0)))
+ return std::nullopt;
+ return {
+ {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
+ };
+ if (auto MatchedMore = MatchConstMul(More)) {
+ if (auto MatchedLess = MatchConstMul(Less)) {
+ if (MatchedMore->second == MatchedLess->second) {
+ More = MatchedMore->first;
+ Less = MatchedLess->first;
+ DiffMul *= MatchedMore->second;
+ continue;
+ }
+ }
+ }
+
// Try to cancel out common factors in two add expressions.
SmallDenseMap<const SCEV *, int, 8> Multiplicity;
auto Add = [&](const SCEV *S, int Mul) {
if (auto *C = dyn_cast<SCEVConstant>(S)) {
if (Mul == 1) {
- Diff += C->getAPInt();
+ Diff += C->getAPInt() * DiffMul;
} else {
assert(Mul == -1);
- Diff -= C->getAPInt();
+ Diff -= C->getAPInt() * DiffMul;
}
} else
Multiplicity[S] += Mul;
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 42aad6ae507bf6..d4d90d80f4cea1 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1142,6 +1142,8 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
%var = load i32, ptr %ptr
%iv2pvar = add i32 %iv2, %var
%iv2pvarp3 = add i32 %iv2pvar, 3
+ %iv2pvarm3 = mul i32 %iv2pvar, 3
+ %iv2pvarp3m3 = mul i32 %iv2pvarp3, 3
%cmp2 = icmp sle i32 %iv2.next, %sz
br i1 %cmp2, label %loop2.body, label %exit
exit:
@@ -1178,6 +1180,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
// %var + {{3,+,1},+,1}
const SCEV *ScevIV2PVarP3 =
SE.getSCEV(getInstructionByName(F, "iv2pvarp3"));
+ // 3 * (%var + {{0,+,1},+,1})
+ const SCEV *ScevIV2PVarM3 =
+ SE.getSCEV(getInstructionByName(F, "iv2pvarm3"));
+ // 3 * (%var + {{3,+,1},+,1})
+ const SCEV *ScevIV2PVarP3M3 =
+ SE.getSCEV(getInstructionByName(F, "iv2pvarp3m3"));
auto diff = [&SE](const SCEV *LHS, const SCEV *RHS) -> std::optional<int> {
auto ConstantDiffOrNone = computeConstantDifference(SE, LHS, RHS);
@@ -1204,6 +1212,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
EXPECT_EQ(diff(ScevIV2P3, ScevIV2), 3);
EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), -3);
+ EXPECT_EQ(diff(ScevIV2PVarP3M3, ScevIV2PVarM3), 9);
EXPECT_EQ(diff(ScevV0, ScevIV), std::nullopt);
EXPECT_EQ(diff(ScevIVNext, ScevV3), std::nullopt);
EXPECT_EQ(diff(ScevYY, ScevV3), std::nullopt);
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Just in case it helps anyone else - I couldn't remember if it was always safe to distribute a multiply over an add in two's complement, so here's the alive2 proof: https://alive2.llvm.org/ce/z/qcDLqq. The nonundef is required, but guaranteed in this case since C is a constant.
Inside computeConstantDifference(), handle the case where both sides are of the form
C * %x
, in which case we can strip off the common multiplication (as long as we remember to multiply by it for the following difference calculation).There is an obvious alternative implementation here, which would be to directly decompose multiplies inside the "Multiplicity" accumulation. This does work, but I've found this to be both significantly slower (because everything has to work on APInt) and more complex in implementation (e.g. because we now need to match back the new More/Less with an arbitrary factor) without providing more power in practice. As such, I went for the simpler variant here.
This is the last step to make computeConstantDifference() sufficiently powerful to replace existing uses of
cast<SCEVConstant>(getMinusSCEV())
with it.