Skip to content

Commit 219ba2f

Browse files
committed
[SCEV] Preserve divisibility and min/max information in applyLoopGuards
applyLoopGuards doesn't always preserve information when there are multiple assumes. This patch handles multiple assumes about a SCEV's divisibility and min/max values, and rewrites the expression into a SCEV that still preserves all of that information. For example, let the trip count of the loop be TC, and consider the following three assumes: 1. __builtin_assume(TC % 8 == 0); 2. __builtin_assume(TC > 0); 3. __builtin_assume(TC < 100); Before this patch, depending on the order in which the assumes were processed, applyLoopGuards could create the following SCEV: max(min((8 * (TC / 8)), 99), 1). This SCEV does not preserve the divisibility-by-8 information, because the clamping constants 99 and 1 are not multiples of 8. After this patch, depending on the assume processing order, applyLoopGuards could create the following SCEV: max(min((8 * (TC / 8)), 96), 8). By aligning 1 up to 8 and aligning 99 down to 96, the new SCEV still preserves all of the information in the original assumes. Differential Revision: https://reviews.llvm.org/D141850
1 parent 74565c3 commit 219ba2f

File tree

3 files changed

+237
-29
lines changed

3 files changed

+237
-29
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 192 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15034,6 +15034,91 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
1503415034
if (MatchRangeCheckIdiom())
1503515035
return;
1503615036

15037+
// Return true if \p Expr is a two-operand MinMax SCEV expression whose first
15038+
// operand is a constant. If so, return in \p SCTy the SCEV type, in \p LHS
15039+
// the constant operand (operand 0) and in \p RHS the other operand (operand 1).
15040+
auto IsMinMaxSCEVWithConstant = [&](const SCEV *Expr, SCEVTypes &SCTy,
15041+
const SCEV *&LHS, const SCEV *&RHS) {
15042+
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15043+
if (MinMax->getNumOperands() != 2)
15044+
return false;
15045+
SCTy = MinMax->getSCEVType();
15046+
if (!isa<SCEVConstant>(MinMax->getOperand(0)))
15047+
return false;
15048+
LHS = MinMax->getOperand(0);
15049+
RHS = MinMax->getOperand(1);
15050+
return true;
15051+
}
15052+
return false;
15053+
};
15054+
15055+
// Return true if \p Expr is a known non-negative constant and \p Divisor is a
15056+
// known positive constant; only then store their APInt values in \p ExprVal and \p DivisorVal.
15057+
auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15058+
APInt &ExprVal, APInt &DivisorVal) {
15059+
if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor))
15060+
return false;
15061+
auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15062+
auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15063+
if (!ConstExpr || !ConstDivisor)
15064+
return false;
15065+
ExprVal = ConstExpr->getAPInt();
15066+
DivisorVal = ConstDivisor->getAPInt();
15067+
return true;
15068+
};
15069+
15070+
// Return a SCEV for the smallest multiple of \p Divisor that is greater than
15071+
// or equal to \p Expr; return \p Expr itself if it is already a multiple.
15072+
// For now, only handle constant Expr and Divisor; otherwise return \p Expr.
15073+
auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15074+
const SCEV *Divisor) {
15075+
APInt ExprVal;
15076+
APInt DivisorVal;
15077+
if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15078+
return Expr;
15079+
APInt Rem = ExprVal.urem(DivisorVal);
15080+
if (!Rem.isZero())
15081+
// Round up: return the constant Expr + Divisor - (Expr % Divisor).
15082+
return getConstant(ExprVal + DivisorVal - Rem);
15083+
return Expr;
15084+
};
15085+
15086+
// Return a SCEV for the largest multiple of \p Divisor that is less than or
15087+
// equal to \p Expr.
15088+
// For now, only handle constant Expr and Divisor; otherwise return \p Expr.
15089+
auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15090+
const SCEV *Divisor) {
15091+
APInt ExprVal;
15092+
APInt DivisorVal;
15093+
if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15094+
return Expr;
15095+
APInt Rem = ExprVal.urem(DivisorVal);
15096+
// Round down: return the constant Expr - (Expr % Divisor).
15097+
return getConstant(ExprVal - Rem);
15098+
};
15099+
15100+
// Apply divisibility by \p Divisor to a MinMax expression with a constant
15101+
// operand, recursively. The constant is aligned down (for min) or up (for max)
15102+
// to a multiple of \p Divisor; any other expression is returned unchanged.
15103+
std::function<const SCEV *(const SCEV *, const SCEV *)>
15104+
ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15105+
const SCEV *Divisor) {
15106+
const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15107+
SCEVTypes SCTy;
15108+
if (!IsMinMaxSCEVWithConstant(MinMaxExpr, SCTy, MinMaxLHS, MinMaxRHS))
15109+
return MinMaxExpr;
15110+
auto IsMin =
15111+
isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15112+
assert(isKnownNonNegative(MinMaxLHS) &&
15113+
"Expected non-negative operand!");
15114+
auto *DivisibleExpr =
15115+
IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15116+
: GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15117+
SmallVector<const SCEV *> Ops = {
15118+
ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15119+
return getMinMaxExpr(SCTy, Ops);
15120+
};
15121+
1503715122
// If we have LHS == 0, check if LHS is computing a property of some unknown
1503815123
// SCEV %v which we can rewrite %v to express explicitly.
1503915124
const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
@@ -15045,7 +15130,12 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
1504515130
const SCEV *URemRHS = nullptr;
1504615131
if (matchURem(LHS, URemLHS, URemRHS)) {
1504715132
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15048-
const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
15133+
auto I = RewriteMap.find(LHSUnknown);
15134+
const SCEV *RewrittenLHS =
15135+
I != RewriteMap.end() ? I->second : LHSUnknown;
15136+
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15137+
const auto *Multiple =
15138+
getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
1504915139
RewriteMap[LHSUnknown] = Multiple;
1505015140
ExprsToRewrite.push_back(LHSUnknown);
1505115141
return;
@@ -15068,48 +15158,128 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
1506815158
auto I = RewriteMap.find(LHS);
1506915159
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
1507015160

15161+
// Check for the SCEV expression (A /u B) * B, where B is a constant, inside
15162+
// \p Expr. The check is done recursively on \p Expr, which is assumed to
15163+
// be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15164+
// /u B) * B was found and, if so, return the divisor B in \p DividesBy. For
15165+
// example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15166+
// (A /u 8) * 8 matches the pattern, and return the constant SCEV 8 in \p
15167+
// DividesBy.
15168+
std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15169+
[&](const SCEV *Expr, const SCEV *&DividesBy) {
15170+
if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15171+
if (Mul->getNumOperands() != 2)
15172+
return false;
15173+
auto *MulLHS = Mul->getOperand(0);
15174+
auto *MulRHS = Mul->getOperand(1);
15175+
if (isa<SCEVConstant>(MulLHS))
15176+
std::swap(MulLHS, MulRHS);
15177+
if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS)) {
15178+
if (Div->getOperand(1) == MulRHS) {
15179+
DividesBy = MulRHS;
15180+
return true;
15181+
}
15182+
}
15183+
}
15184+
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15185+
return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15186+
HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15187+
}
15188+
return false;
15189+
};
15190+
15191+
// Return true if \p Expr is known to be divisible by \p DividesBy.
15192+
std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15193+
[&](const SCEV *Expr, const SCEV *DividesBy) {
15194+
if (getURemExpr(Expr, DividesBy)->isZero())
15195+
return true;
15196+
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15197+
return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15198+
IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15199+
}
15200+
return false;
15201+
};
15202+
15203+
const SCEV *DividesBy = nullptr;
15204+
if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15205+
// Check that the whole expression is divisible by DividesBy.
15206+
DividesBy =
15207+
IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15208+
1507115209
const SCEV *RewrittenRHS = nullptr;
1507215210
switch (Predicate) {
1507315211
case CmpInst::ICMP_ULT: {
1507415212
if (RHS->getType()->isPointerTy())
1507515213
break;
1507615214
const SCEV *One = getOne(RHS->getType());
15077-
RewrittenRHS =
15078-
getUMinExpr(RewrittenLHS, getMinusSCEV(getUMaxExpr(RHS, One), One));
15215+
auto *ModifiedRHS = getMinusSCEV(getUMaxExpr(RHS, One), One);
15216+
ModifiedRHS =
15217+
DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15218+
: ModifiedRHS;
15219+
RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS);
1507915220
break;
1508015221
}
15081-
case CmpInst::ICMP_SLT:
15082-
RewrittenRHS =
15083-
getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
15222+
case CmpInst::ICMP_SLT: {
15223+
auto *ModifiedRHS = getMinusSCEV(RHS, getOne(RHS->getType()));
15224+
ModifiedRHS =
15225+
DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15226+
: ModifiedRHS;
15227+
RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS);
1508415228
break;
15085-
case CmpInst::ICMP_ULE:
15086-
RewrittenRHS = getUMinExpr(RewrittenLHS, RHS);
15229+
}
15230+
case CmpInst::ICMP_ULE: {
15231+
auto *ModifiedRHS =
15232+
DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15233+
RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS);
1508715234
break;
15088-
case CmpInst::ICMP_SLE:
15089-
RewrittenRHS = getSMinExpr(RewrittenLHS, RHS);
15235+
}
15236+
case CmpInst::ICMP_SLE: {
15237+
auto *ModifiedRHS =
15238+
DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15239+
RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS);
1509015240
break;
15091-
case CmpInst::ICMP_UGT:
15092-
RewrittenRHS =
15093-
getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
15241+
}
15242+
case CmpInst::ICMP_UGT: {
15243+
auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType()));
15244+
ModifiedRHS = DividesBy
15245+
? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15246+
: ModifiedRHS;
15247+
RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
1509415248
break;
15095-
case CmpInst::ICMP_SGT:
15096-
RewrittenRHS =
15097-
getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
15249+
}
15250+
case CmpInst::ICMP_SGT: {
15251+
auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType()));
15252+
ModifiedRHS = DividesBy
15253+
? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15254+
: ModifiedRHS;
15255+
RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS);
1509815256
break;
15099-
case CmpInst::ICMP_UGE:
15100-
RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS);
15257+
}
15258+
case CmpInst::ICMP_UGE: {
15259+
auto *ModifiedRHS =
15260+
DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15261+
RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
1510115262
break;
15102-
case CmpInst::ICMP_SGE:
15103-
RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS);
15263+
}
15264+
case CmpInst::ICMP_SGE: {
15265+
auto *ModifiedRHS =
15266+
DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15267+
RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS);
1510415268
break;
15269+
}
1510515270
case CmpInst::ICMP_EQ:
1510615271
if (isa<SCEVConstant>(RHS))
1510715272
RewrittenRHS = RHS;
1510815273
break;
1510915274
case CmpInst::ICMP_NE:
1511015275
if (isa<SCEVConstant>(RHS) &&
15111-
cast<SCEVConstant>(RHS)->getValue()->isNullValue())
15112-
RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType()));
15276+
cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15277+
auto *ModifiedRHS = getOne(RHS->getType());
15278+
ModifiedRHS = DividesBy
15279+
? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15280+
: ModifiedRHS;
15281+
RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
15282+
}
1511315283
break;
1511415284
default:
1511515285
break;

llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ define void @test_trip_multiple_4_ugt_5_order_swapped(i32 %num) {
125125
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
126126
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
127127
; CHECK-NEXT: Predicates:
128-
; CHECK: Loop %for.body: Trip multiple is 2
128+
; CHECK: Loop %for.body: Trip multiple is 4
129129
;
130130
entry:
131131
%u = urem i32 %num, 4
@@ -196,7 +196,7 @@ define void @test_trip_multiple_4_sgt_5_order_swapped(i32 %num) {
196196
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
197197
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
198198
; CHECK-NEXT: Predicates:
199-
; CHECK: Loop %for.body: Trip multiple is 2
199+
; CHECK: Loop %for.body: Trip multiple is 4
200200
;
201201
entry:
202202
%u = urem i32 %num, 4
@@ -267,7 +267,7 @@ define void @test_trip_multiple_4_uge_5_order_swapped(i32 %num) {
267267
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
268268
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
269269
; CHECK-NEXT: Predicates:
270-
; CHECK: Loop %for.body: Trip multiple is 1
270+
; CHECK: Loop %for.body: Trip multiple is 4
271271
;
272272
entry:
273273
%u = urem i32 %num, 4
@@ -338,7 +338,7 @@ define void @test_trip_multiple_4_sge_5_order_swapped(i32 %num) {
338338
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
339339
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
340340
; CHECK-NEXT: Predicates:
341-
; CHECK: Loop %for.body: Trip multiple is 1
341+
; CHECK: Loop %for.body: Trip multiple is 4
342342
;
343343
entry:
344344
%u = urem i32 %num, 4
@@ -409,7 +409,7 @@ define void @test_trip_multiple_4_upper_lower_bounds(i32 %num) {
409409
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
410410
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
411411
; CHECK-NEXT: Predicates:
412-
; CHECK: Loop %for.body: Trip multiple is 1
412+
; CHECK: Loop %for.body: Trip multiple is 4
413413
;
414414
entry:
415415
%cmp.1 = icmp uge i32 %num, 5
@@ -446,7 +446,7 @@ define void @test_trip_multiple_4_upper_lower_bounds_swapped1(i32 %num) {
446446
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
447447
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
448448
; CHECK-NEXT: Predicates:
449-
; CHECK: Loop %for.body: Trip multiple is 1
449+
; CHECK: Loop %for.body: Trip multiple is 4
450450
;
451451
entry:
452452
%cmp.1 = icmp uge i32 %num, 5
@@ -483,7 +483,7 @@ define void @test_trip_multiple_4_upper_lower_bounds_swapped2(i32 %num) {
483483
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
484484
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
485485
; CHECK-NEXT: Predicates:
486-
; CHECK: Loop %for.body: Trip multiple is 1
486+
; CHECK: Loop %for.body: Trip multiple is 4
487487
;
488488
entry:
489489
%cmp.1 = icmp uge i32 %num, 5

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,4 +1744,42 @@ TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromMultiDemArray) {
17441744
});
17451745
}
17461746

1747+
TEST_F(ScalarEvolutionsTest, ApplyLoopGuards) {
1748+
LLVMContext C;
1749+
SMDiagnostic Err;
1750+
std::unique_ptr<Module> M = parseAssemblyString(
1751+
"declare void @llvm.assume(i1)\n"
1752+
"define void @test(i32 %num) {\n"
1753+
"entry:\n"
1754+
" %u = urem i32 %num, 4\n"
1755+
" %cmp = icmp eq i32 %u, 0\n"
1756+
" tail call void @llvm.assume(i1 %cmp)\n"
1757+
" %cmp.1 = icmp ugt i32 %num, 0\n"
1758+
" tail call void @llvm.assume(i1 %cmp.1)\n"
1759+
" br label %for.body\n"
1760+
"for.body:\n"
1761+
" %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]\n"
1762+
" %inc = add nuw nsw i32 %i.010, 1\n"
1763+
" %cmp2 = icmp ult i32 %inc, %num\n"
1764+
" br i1 %cmp2, label %for.body, label %exit\n"
1765+
"exit:\n"
1766+
" ret void\n"
1767+
"}\n",
1768+
Err, C);
1769+
1770+
ASSERT_TRUE(M && "Could not parse module?");
1771+
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
1772+
1773+
runWithSE(*M, "test", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
1774+
auto *TCScev = SE.getSCEV(getArgByName(F, "num"));
1775+
auto *ApplyLoopGuardsTC = SE.applyLoopGuards(TCScev, *LI.begin());
1776+
// Assert that the new TC is (4 * ((4 umax %num) /u 4)): the lower bound implied by `%num ugt 0` is aligned up to the divisor 4.
1777+
APInt Four(32, 4);
1778+
auto *Constant4 = SE.getConstant(Four);
1779+
auto *Max = SE.getUMaxExpr(TCScev, Constant4);
1780+
auto *Mul = SE.getMulExpr(SE.getUDivExpr(Max, Constant4), Constant4);
1781+
ASSERT_TRUE(Mul == ApplyLoopGuardsTC);
1782+
});
1783+
}
1784+
17471785
} // end namespace llvm

0 commit comments

Comments
 (0)