Skip to content

Commit 85c5265

Browse files
authored
[SCEV] Unify and optimize constant folding (NFC) (#101473)
Add a common constantFoldAndGroupOps() helper that takes care of constant folding and grouping transforms that are common to all nary ops. This moves the constant folding prior to grouping, which is more efficient, and excludes any constant from the sort. The constant folding has hooks for folding, identity constants and absorber constants. This gives a compile-time improvement for SCEV-heavy workloads like lencod.
1 parent 10df320 commit 85c5265

File tree

1 file changed

+92
-100
lines changed

1 file changed

+92
-100
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 92 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,49 @@ static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
831831
});
832832
}
833833

834+
/// Performs a number of common optimizations on the passed \p Ops. If the
835+
/// whole expression reduces down to a single operand, it will be returned.
836+
///
837+
/// The following optimizations are performed:
838+
/// * Fold constants using the \p Fold function.
839+
/// * Remove identity constants satisfying \p IsIdentity.
840+
/// * If a constant satisfies \p IsAbsorber, return it.
841+
/// * Sort operands by complexity.
842+
template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
843+
static const SCEV *
844+
constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
845+
SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
846+
IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
847+
const SCEVConstant *Folded = nullptr;
848+
for (unsigned Idx = 0; Idx < Ops.size();) {
849+
const SCEV *Op = Ops[Idx];
850+
if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
851+
if (!Folded)
852+
Folded = C;
853+
else
854+
Folded = cast<SCEVConstant>(
855+
SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
856+
Ops.erase(Ops.begin() + Idx);
857+
continue;
858+
}
859+
++Idx;
860+
}
861+
862+
if (Ops.empty()) {
863+
assert(Folded && "Must have folded value");
864+
return Folded;
865+
}
866+
867+
if (Folded && IsAbsorber(Folded->getAPInt()))
868+
return Folded;
869+
870+
GroupByComplexity(Ops, &LI, DT);
871+
if (Folded && !IsIdentity(Folded->getAPInt()))
872+
Ops.insert(Ops.begin(), Folded);
873+
874+
return Ops.size() == 1 ? Ops[0] : nullptr;
875+
}
876+
834877
//===----------------------------------------------------------------------===//
835878
// Simple SCEV method implementations
836879
//===----------------------------------------------------------------------===//
@@ -2504,30 +2547,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
25042547
assert(NumPtrs <= 1 && "add has at most one pointer operand");
25052548
#endif
25062549

2507-
// Sort by complexity, this groups all similar expression types together.
2508-
GroupByComplexity(Ops, &LI, DT);
2509-
2510-
// If there are any constants, fold them together.
2511-
unsigned Idx = 0;
2512-
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2513-
++Idx;
2514-
assert(Idx < Ops.size());
2515-
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2516-
// We found two constants, fold them together!
2517-
Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2518-
if (Ops.size() == 2) return Ops[0];
2519-
Ops.erase(Ops.begin()+1); // Erase the folded element
2520-
LHSC = cast<SCEVConstant>(Ops[0]);
2521-
}
2522-
2523-
// If we are left with a constant zero being added, strip it off.
2524-
if (LHSC->getValue()->isZero()) {
2525-
Ops.erase(Ops.begin());
2526-
--Idx;
2527-
}
2550+
const SCEV *Folded = constantFoldAndGroupOps(
2551+
*this, LI, DT, Ops,
2552+
[](const APInt &C1, const APInt &C2) { return C1 + C2; },
2553+
[](const APInt &C) { return C.isZero(); }, // identity
2554+
[](const APInt &C) { return false; }); // absorber
2555+
if (Folded)
2556+
return Folded;
25282557

2529-
if (Ops.size() == 1) return Ops[0];
2530-
}
2558+
unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
25312559

25322560
// Delay expensive flag strengthening until necessary.
25332561
auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
@@ -3097,35 +3125,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
30973125
"SCEVMulExpr operand types don't match!");
30983126
#endif
30993127

3100-
// Sort by complexity, this groups all similar expression types together.
3101-
GroupByComplexity(Ops, &LI, DT);
3102-
3103-
// If there are any constants, fold them together.
3104-
unsigned Idx = 0;
3105-
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3106-
++Idx;
3107-
assert(Idx < Ops.size());
3108-
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3109-
// We found two constants, fold them together!
3110-
Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3111-
if (Ops.size() == 2) return Ops[0];
3112-
Ops.erase(Ops.begin()+1); // Erase the folded element
3113-
LHSC = cast<SCEVConstant>(Ops[0]);
3114-
}
3115-
3116-
// If we have a multiply of zero, it will always be zero.
3117-
if (LHSC->getValue()->isZero())
3118-
return LHSC;
3119-
3120-
// If we are left with a constant one being multiplied, strip it off.
3121-
if (LHSC->getValue()->isOne()) {
3122-
Ops.erase(Ops.begin());
3123-
--Idx;
3124-
}
3125-
3126-
if (Ops.size() == 1)
3127-
return Ops[0];
3128-
}
3128+
const SCEV *Folded = constantFoldAndGroupOps(
3129+
*this, LI, DT, Ops,
3130+
[](const APInt &C1, const APInt &C2) { return C1 * C2; },
3131+
[](const APInt &C) { return C.isOne(); }, // identity
3132+
[](const APInt &C) { return C.isZero(); }); // absorber
3133+
if (Folded)
3134+
return Folded;
31293135

31303136
// Delay expensive flag strengthening until necessary.
31313137
auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
@@ -3202,6 +3208,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
32023208
}
32033209

32043210
// Skip over the add expression until we get to a multiply.
3211+
unsigned Idx = 0;
32053212
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
32063213
++Idx;
32073214

@@ -3829,61 +3836,46 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
38293836
bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
38303837
bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
38313838

3832-
// Sort by complexity, this groups all similar expression types together.
3833-
GroupByComplexity(Ops, &LI, DT);
3839+
const SCEV *Folded = constantFoldAndGroupOps(
3840+
*this, LI, DT, Ops,
3841+
[&](const APInt &C1, const APInt &C2) {
3842+
switch (Kind) {
3843+
case scSMaxExpr:
3844+
return APIntOps::smax(C1, C2);
3845+
case scSMinExpr:
3846+
return APIntOps::smin(C1, C2);
3847+
case scUMaxExpr:
3848+
return APIntOps::umax(C1, C2);
3849+
case scUMinExpr:
3850+
return APIntOps::umin(C1, C2);
3851+
default:
3852+
llvm_unreachable("Unknown SCEV min/max opcode");
3853+
}
3854+
},
3855+
[&](const APInt &C) {
3856+
// identity
3857+
if (IsMax)
3858+
return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3859+
else
3860+
return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3861+
},
3862+
[&](const APInt &C) {
3863+
// absorber
3864+
if (IsMax)
3865+
return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3866+
else
3867+
return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3868+
});
3869+
if (Folded)
3870+
return Folded;
38343871

38353872
// Check if we have created the same expression before.
38363873
if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
38373874
return S;
38383875
}
38393876

3840-
// If there are any constants, fold them together.
3841-
unsigned Idx = 0;
3842-
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3843-
++Idx;
3844-
assert(Idx < Ops.size());
3845-
auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3846-
switch (Kind) {
3847-
case scSMaxExpr:
3848-
return APIntOps::smax(LHS, RHS);
3849-
case scSMinExpr:
3850-
return APIntOps::smin(LHS, RHS);
3851-
case scUMaxExpr:
3852-
return APIntOps::umax(LHS, RHS);
3853-
case scUMinExpr:
3854-
return APIntOps::umin(LHS, RHS);
3855-
default:
3856-
llvm_unreachable("Unknown SCEV min/max opcode");
3857-
}
3858-
};
3859-
3860-
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3861-
// We found two constants, fold them together!
3862-
ConstantInt *Fold = ConstantInt::get(
3863-
getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3864-
Ops[0] = getConstant(Fold);
3865-
Ops.erase(Ops.begin()+1); // Erase the folded element
3866-
if (Ops.size() == 1) return Ops[0];
3867-
LHSC = cast<SCEVConstant>(Ops[0]);
3868-
}
3869-
3870-
bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3871-
bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3872-
3873-
if (IsMax ? IsMinV : IsMaxV) {
3874-
// If we are left with a constant minimum(/maximum)-int, strip it off.
3875-
Ops.erase(Ops.begin());
3876-
--Idx;
3877-
} else if (IsMax ? IsMaxV : IsMinV) {
3878-
// If we have a max(/min) with a constant maximum(/minimum)-int,
3879-
// it will always be the extremum.
3880-
return LHSC;
3881-
}
3882-
3883-
if (Ops.size() == 1) return Ops[0];
3884-
}
3885-
38863877
// Find the first operation of the same kind
3878+
unsigned Idx = 0;
38873879
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
38883880
++Idx;
38893881

0 commit comments

Comments
 (0)