@@ -831,6 +831,49 @@ static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
831
831
});
832
832
}
833
833
834
+ /// Performs a number of common optimizations on the passed \p Ops. If the
835
+ /// whole expression reduces down to a single operand, it will be returned.
836
+ ///
837
+ /// The following optimizations are performed:
838
+ /// * Fold constants using the \p Fold function.
839
+ /// * Remove identity constants satisfying \p IsIdentity.
840
+ /// * If a constant satisfies \p IsAbsorber, return it.
841
+ /// * Sort operands by complexity.
842
+ template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
843
+ static const SCEV *
844
+ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
845
+ SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
846
+ IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
847
+ const SCEVConstant *Folded = nullptr;
848
+ for (unsigned Idx = 0; Idx < Ops.size();) {
849
+ const SCEV *Op = Ops[Idx];
850
+ if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
851
+ if (!Folded)
852
+ Folded = C;
853
+ else
854
+ Folded = cast<SCEVConstant>(
855
+ SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
856
+ Ops.erase(Ops.begin() + Idx);
857
+ continue;
858
+ }
859
+ ++Idx;
860
+ }
861
+
862
+ if (Ops.empty()) {
863
+ assert(Folded && "Must have folded value");
864
+ return Folded;
865
+ }
866
+
867
+ if (Folded && IsAbsorber(Folded->getAPInt()))
868
+ return Folded;
869
+
870
+ GroupByComplexity(Ops, &LI, DT);
871
+ if (Folded && !IsIdentity(Folded->getAPInt()))
872
+ Ops.insert(Ops.begin(), Folded);
873
+
874
+ return Ops.size() == 1 ? Ops[0] : nullptr;
875
+ }
876
+
834
877
//===----------------------------------------------------------------------===//
835
878
// Simple SCEV method implementations
836
879
//===----------------------------------------------------------------------===//
@@ -2504,30 +2547,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2504
2547
assert(NumPtrs <= 1 && "add has at most one pointer operand");
2505
2548
#endif
2506
2549
2507
- // Sort by complexity, this groups all similar expression types together.
2508
- GroupByComplexity(Ops, &LI, DT);
2509
-
2510
- // If there are any constants, fold them together.
2511
- unsigned Idx = 0;
2512
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2513
- ++Idx;
2514
- assert(Idx < Ops.size());
2515
- while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2516
- // We found two constants, fold them together!
2517
- Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2518
- if (Ops.size() == 2) return Ops[0];
2519
- Ops.erase(Ops.begin()+1); // Erase the folded element
2520
- LHSC = cast<SCEVConstant>(Ops[0]);
2521
- }
2522
-
2523
- // If we are left with a constant zero being added, strip it off.
2524
- if (LHSC->getValue()->isZero()) {
2525
- Ops.erase(Ops.begin());
2526
- --Idx;
2527
- }
2550
+ const SCEV *Folded = constantFoldAndGroupOps(
2551
+ *this, LI, DT, Ops,
2552
+ [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2553
+ [](const APInt &C) { return C.isZero(); }, // identity
2554
+ [](const APInt &C) { return false; }); // absorber
2555
+ if (Folded)
2556
+ return Folded;
2528
2557
2529
- if (Ops.size() == 1) return Ops[0];
2530
- }
2558
+ unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2531
2559
2532
2560
// Delay expensive flag strengthening until necessary.
2533
2561
auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
@@ -3097,35 +3125,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3097
3125
"SCEVMulExpr operand types don't match!");
3098
3126
#endif
3099
3127
3100
- // Sort by complexity, this groups all similar expression types together.
3101
- GroupByComplexity(Ops, &LI, DT);
3102
-
3103
- // If there are any constants, fold them together.
3104
- unsigned Idx = 0;
3105
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3106
- ++Idx;
3107
- assert(Idx < Ops.size());
3108
- while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3109
- // We found two constants, fold them together!
3110
- Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3111
- if (Ops.size() == 2) return Ops[0];
3112
- Ops.erase(Ops.begin()+1); // Erase the folded element
3113
- LHSC = cast<SCEVConstant>(Ops[0]);
3114
- }
3115
-
3116
- // If we have a multiply of zero, it will always be zero.
3117
- if (LHSC->getValue()->isZero())
3118
- return LHSC;
3119
-
3120
- // If we are left with a constant one being multiplied, strip it off.
3121
- if (LHSC->getValue()->isOne()) {
3122
- Ops.erase(Ops.begin());
3123
- --Idx;
3124
- }
3125
-
3126
- if (Ops.size() == 1)
3127
- return Ops[0];
3128
- }
3128
+ const SCEV *Folded = constantFoldAndGroupOps(
3129
+ *this, LI, DT, Ops,
3130
+ [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3131
+ [](const APInt &C) { return C.isOne(); }, // identity
3132
+ [](const APInt &C) { return C.isZero(); }); // absorber
3133
+ if (Folded)
3134
+ return Folded;
3129
3135
3130
3136
// Delay expensive flag strengthening until necessary.
3131
3137
auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
@@ -3202,6 +3208,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3202
3208
}
3203
3209
3204
3210
// Skip over the add expression until we get to a multiply.
3211
+ unsigned Idx = 0;
3205
3212
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3206
3213
++Idx;
3207
3214
@@ -3829,61 +3836,46 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3829
3836
bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3830
3837
bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3831
3838
3832
- // Sort by complexity, this groups all similar expression types together.
3833
- GroupByComplexity(Ops, &LI, DT);
3839
+ const SCEV *Folded = constantFoldAndGroupOps(
3840
+ *this, LI, DT, Ops,
3841
+ [&](const APInt &C1, const APInt &C2) {
3842
+ switch (Kind) {
3843
+ case scSMaxExpr:
3844
+ return APIntOps::smax(C1, C2);
3845
+ case scSMinExpr:
3846
+ return APIntOps::smin(C1, C2);
3847
+ case scUMaxExpr:
3848
+ return APIntOps::umax(C1, C2);
3849
+ case scUMinExpr:
3850
+ return APIntOps::umin(C1, C2);
3851
+ default:
3852
+ llvm_unreachable("Unknown SCEV min/max opcode");
3853
+ }
3854
+ },
3855
+ [&](const APInt &C) {
3856
+ // identity
3857
+ if (IsMax)
3858
+ return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3859
+ else
3860
+ return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3861
+ },
3862
+ [&](const APInt &C) {
3863
+ // absorber
3864
+ if (IsMax)
3865
+ return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3866
+ else
3867
+ return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3868
+ });
3869
+ if (Folded)
3870
+ return Folded;
3834
3871
3835
3872
// Check if we have created the same expression before.
3836
3873
if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3837
3874
return S;
3838
3875
}
3839
3876
3840
- // If there are any constants, fold them together.
3841
- unsigned Idx = 0;
3842
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3843
- ++Idx;
3844
- assert(Idx < Ops.size());
3845
- auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3846
- switch (Kind) {
3847
- case scSMaxExpr:
3848
- return APIntOps::smax(LHS, RHS);
3849
- case scSMinExpr:
3850
- return APIntOps::smin(LHS, RHS);
3851
- case scUMaxExpr:
3852
- return APIntOps::umax(LHS, RHS);
3853
- case scUMinExpr:
3854
- return APIntOps::umin(LHS, RHS);
3855
- default:
3856
- llvm_unreachable("Unknown SCEV min/max opcode");
3857
- }
3858
- };
3859
-
3860
- while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3861
- // We found two constants, fold them together!
3862
- ConstantInt *Fold = ConstantInt::get(
3863
- getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3864
- Ops[0] = getConstant(Fold);
3865
- Ops.erase(Ops.begin()+1); // Erase the folded element
3866
- if (Ops.size() == 1) return Ops[0];
3867
- LHSC = cast<SCEVConstant>(Ops[0]);
3868
- }
3869
-
3870
- bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3871
- bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3872
-
3873
- if (IsMax ? IsMinV : IsMaxV) {
3874
- // If we are left with a constant minimum(/maximum)-int, strip it off.
3875
- Ops.erase(Ops.begin());
3876
- --Idx;
3877
- } else if (IsMax ? IsMaxV : IsMinV) {
3878
- // If we have a max(/min) with a constant maximum(/minimum)-int,
3879
- // it will always be the extremum.
3880
- return LHSC;
3881
- }
3882
-
3883
- if (Ops.size() == 1) return Ops[0];
3884
- }
3885
-
3886
3877
// Find the first operation of the same kind
3878
+ unsigned Idx = 0;
3887
3879
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3888
3880
++Idx;
3889
3881
0 commit comments