@@ -15034,6 +15034,91 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15034
15034
if (MatchRangeCheckIdiom())
15035
15035
return;
15036
15036
15037
+ // Return true if \p Expr is a MinMax SCEV expression with a constant
15038
+ // operand. If so, return in \p SCTy the SCEV type and in \p RHS the
15039
+ // non-constant operand and in \p LHS the constant operand.
15040
+ auto IsMinMaxSCEVWithConstant = [&](const SCEV *Expr, SCEVTypes &SCTy,
15041
+ const SCEV *&LHS, const SCEV *&RHS) {
15042
+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15043
+ if (MinMax->getNumOperands() != 2)
15044
+ return false;
15045
+ SCTy = MinMax->getSCEVType();
15046
+ if (!isa<SCEVConstant>(MinMax->getOperand(0)))
15047
+ return false;
15048
+ LHS = MinMax->getOperand(0);
15049
+ RHS = MinMax->getOperand(1);
15050
+ return true;
15051
+ }
15052
+ return false;
15053
+ };
15054
+
15055
+ // Checks whether Expr is a non-negative constant, and Divisor is a positive
15056
+ // constant, and returns their APInt in ExprVal and in DivisorVal.
15057
+ auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15058
+ APInt &ExprVal, APInt &DivisorVal) {
15059
+ if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor))
15060
+ return false;
15061
+ auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15062
+ auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15063
+ if (!ConstExpr || !ConstDivisor)
15064
+ return false;
15065
+ ExprVal = ConstExpr->getAPInt();
15066
+ DivisorVal = ConstDivisor->getAPInt();
15067
+ return true;
15068
+ };
15069
+
15070
+ // Return a new SCEV that modifies \p Expr to the closest number divides by
15071
+ // \p Divisor and greater or equal than Expr.
15072
+ // For now, only handle constant Expr and Divisor.
15073
+ auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15074
+ const SCEV *Divisor) {
15075
+ APInt ExprVal;
15076
+ APInt DivisorVal;
15077
+ if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15078
+ return Expr;
15079
+ APInt Rem = ExprVal.urem(DivisorVal);
15080
+ if (!Rem.isZero())
15081
+ // return the SCEV: Expr + Divisor - Expr % Divisor
15082
+ return getConstant(ExprVal + DivisorVal - Rem);
15083
+ return Expr;
15084
+ };
15085
+
15086
+ // Return a new SCEV that modifies \p Expr to the closest number divides by
15087
+ // \p Divisor and less or equal than Expr.
15088
+ // For now, only handle constant Expr and Divisor.
15089
+ auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15090
+ const SCEV *Divisor) {
15091
+ APInt ExprVal;
15092
+ APInt DivisorVal;
15093
+ if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15094
+ return Expr;
15095
+ APInt Rem = ExprVal.urem(DivisorVal);
15096
+ // return the SCEV: Expr - Expr % Divisor
15097
+ return getConstant(ExprVal - Rem);
15098
+ };
15099
+
15100
+ // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15101
+ // recursively. This is done by aligning up/down the constant value to the
15102
+ // Divisor.
15103
+ std::function<const SCEV *(const SCEV *, const SCEV *)>
15104
+ ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15105
+ const SCEV *Divisor) {
15106
+ const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15107
+ SCEVTypes SCTy;
15108
+ if (!IsMinMaxSCEVWithConstant(MinMaxExpr, SCTy, MinMaxLHS, MinMaxRHS))
15109
+ return MinMaxExpr;
15110
+ auto IsMin =
15111
+ isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15112
+ assert(isKnownNonNegative(MinMaxLHS) &&
15113
+ "Expected non-negative operand!");
15114
+ auto *DivisibleExpr =
15115
+ IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15116
+ : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15117
+ SmallVector<const SCEV *> Ops = {
15118
+ ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15119
+ return getMinMaxExpr(SCTy, Ops);
15120
+ };
15121
+
15037
15122
// If we have LHS == 0, check if LHS is computing a property of some unknown
15038
15123
// SCEV %v which we can rewrite %v to express explicitly.
15039
15124
const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
@@ -15045,7 +15130,12 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15045
15130
const SCEV *URemRHS = nullptr;
15046
15131
if (matchURem(LHS, URemLHS, URemRHS)) {
15047
15132
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15048
- const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
15133
+ auto I = RewriteMap.find(LHSUnknown);
15134
+ const SCEV *RewrittenLHS =
15135
+ I != RewriteMap.end() ? I->second : LHSUnknown;
15136
+ RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15137
+ const auto *Multiple =
15138
+ getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15049
15139
RewriteMap[LHSUnknown] = Multiple;
15050
15140
ExprsToRewrite.push_back(LHSUnknown);
15051
15141
return;
@@ -15068,48 +15158,128 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15068
15158
auto I = RewriteMap.find(LHS);
15069
15159
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
15070
15160
15161
+ // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15162
+ // \p Expr. The check is done recuresively on \p Expr, which is assumed to
15163
+ // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15164
+ // /u B) * B was found, and return the divisor B in \p DividesBy. For
15165
+ // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15166
+ // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15167
+ // DividesBy.
15168
+ std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15169
+ [&](const SCEV *Expr, const SCEV *&DividesBy) {
15170
+ if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15171
+ if (Mul->getNumOperands() != 2)
15172
+ return false;
15173
+ auto *MulLHS = Mul->getOperand(0);
15174
+ auto *MulRHS = Mul->getOperand(1);
15175
+ if (isa<SCEVConstant>(MulLHS))
15176
+ std::swap(MulLHS, MulRHS);
15177
+ if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS)) {
15178
+ if (Div->getOperand(1) == MulRHS) {
15179
+ DividesBy = MulRHS;
15180
+ return true;
15181
+ }
15182
+ }
15183
+ }
15184
+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15185
+ return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15186
+ HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15187
+ }
15188
+ return false;
15189
+ };
15190
+
15191
+ // Return true if Expr known to divide by \p DividesBy.
15192
+ std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15193
+ [&](const SCEV *Expr, const SCEV *DividesBy) {
15194
+ if (getURemExpr(Expr, DividesBy)->isZero())
15195
+ return true;
15196
+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15197
+ return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15198
+ IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15199
+ }
15200
+ return false;
15201
+ };
15202
+
15203
+ const SCEV *DividesBy = nullptr;
15204
+ if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15205
+ // Check that the whole expression is divided by DividesBy
15206
+ DividesBy =
15207
+ IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15208
+
15071
15209
const SCEV *RewrittenRHS = nullptr;
15072
15210
switch (Predicate) {
15073
15211
case CmpInst::ICMP_ULT: {
15074
15212
if (RHS->getType()->isPointerTy())
15075
15213
break;
15076
15214
const SCEV *One = getOne(RHS->getType());
15077
- RewrittenRHS =
15078
- getUMinExpr(RewrittenLHS, getMinusSCEV(getUMaxExpr(RHS, One), One));
15215
+ auto *ModifiedRHS = getMinusSCEV(getUMaxExpr(RHS, One), One);
15216
+ ModifiedRHS =
15217
+ DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15218
+ : ModifiedRHS;
15219
+ RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS);
15079
15220
break;
15080
15221
}
15081
- case CmpInst::ICMP_SLT:
15082
- RewrittenRHS =
15083
- getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
15222
+ case CmpInst::ICMP_SLT: {
15223
+ auto *ModifiedRHS = getMinusSCEV(RHS, getOne(RHS->getType()));
15224
+ ModifiedRHS =
15225
+ DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15226
+ : ModifiedRHS;
15227
+ RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS);
15084
15228
break;
15085
- case CmpInst::ICMP_ULE:
15086
- RewrittenRHS = getUMinExpr(RewrittenLHS, RHS);
15229
+ }
15230
+ case CmpInst::ICMP_ULE: {
15231
+ auto *ModifiedRHS =
15232
+ DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15233
+ RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS);
15087
15234
break;
15088
- case CmpInst::ICMP_SLE:
15089
- RewrittenRHS = getSMinExpr(RewrittenLHS, RHS);
15235
+ }
15236
+ case CmpInst::ICMP_SLE: {
15237
+ auto *ModifiedRHS =
15238
+ DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15239
+ RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS);
15090
15240
break;
15091
- case CmpInst::ICMP_UGT:
15092
- RewrittenRHS =
15093
- getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
15241
+ }
15242
+ case CmpInst::ICMP_UGT: {
15243
+ auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType()));
15244
+ ModifiedRHS = DividesBy
15245
+ ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15246
+ : ModifiedRHS;
15247
+ RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
15094
15248
break;
15095
- case CmpInst::ICMP_SGT:
15096
- RewrittenRHS =
15097
- getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
15249
+ }
15250
+ case CmpInst::ICMP_SGT: {
15251
+ auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType()));
15252
+ ModifiedRHS = DividesBy
15253
+ ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15254
+ : ModifiedRHS;
15255
+ RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS);
15098
15256
break;
15099
- case CmpInst::ICMP_UGE:
15100
- RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS);
15257
+ }
15258
+ case CmpInst::ICMP_UGE: {
15259
+ auto *ModifiedRHS =
15260
+ DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15261
+ RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
15101
15262
break;
15102
- case CmpInst::ICMP_SGE:
15103
- RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS);
15263
+ }
15264
+ case CmpInst::ICMP_SGE: {
15265
+ auto *ModifiedRHS =
15266
+ DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15267
+ RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS);
15104
15268
break;
15269
+ }
15105
15270
case CmpInst::ICMP_EQ:
15106
15271
if (isa<SCEVConstant>(RHS))
15107
15272
RewrittenRHS = RHS;
15108
15273
break;
15109
15274
case CmpInst::ICMP_NE:
15110
15275
if (isa<SCEVConstant>(RHS) &&
15111
- cast<SCEVConstant>(RHS)->getValue()->isNullValue())
15112
- RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType()));
15276
+ cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15277
+ auto *ModifiedRHS = getOne(RHS->getType());
15278
+ ModifiedRHS = DividesBy
15279
+ ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15280
+ : ModifiedRHS;
15281
+ RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
15282
+ }
15113
15283
break;
15114
15284
default:
15115
15285
break;
0 commit comments