Skip to content

Commit 9c4baf5

Browse files
committed
[ScalarEvolution] Strictly enforce pointer/int type rules.
Rules: 1. SCEVUnknown is a pointer if and only if the LLVM IR value is a pointer. 2. SCEVPtrToInt is never a pointer. 3. If any other SCEV expression has no pointer operands, the result is an integer. 4. If a SCEVAddExpr has exactly one pointer operand, the result is a pointer. 5. If a SCEVAddRecExpr's first operand is a pointer, and it has no other pointer operands, the result is a pointer. 6. If every operand of a SCEVMinMaxExpr is a pointer, the result is a pointer. 7. Otherwise, the SCEV expression is invalid. I'm not sure how useful rule 6 is in practice. If we exclude it, we can guarantee that ScalarEvolution::getPointerBase always returns a SCEVUnknown, which might be a helpful property. Anyway, I'll leave that for a followup. This is basically mop-up at this point; all the changes with significant functional effects have landed. Some of the remaining changes could be split off, but I don't see much point. Differential Revision: https://reviews.llvm.org/D105510
1 parent 8e9216f commit 9c4baf5

File tree

4 files changed

+51
-21
lines changed

4 files changed

+51
-21
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
11921192
"This is not a truncating conversion!");
11931193
assert(isSCEVable(Ty) &&
11941194
"This is not a conversion to a SCEVable type!");
1195+
assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
11951196
Ty = getEffectiveSCEVType(Ty);
11961197

11971198
FoldingSetNodeID ID;
@@ -1581,6 +1582,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
15811582
"This is not an extending conversion!");
15821583
assert(isSCEVable(Ty) &&
15831584
"This is not a conversion to a SCEVable type!");
1585+
assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
15841586
Ty = getEffectiveSCEVType(Ty);
15851587

15861588
// Fold if the operand is constant.
@@ -1883,6 +1885,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
18831885
"This is not an extending conversion!");
18841886
assert(isSCEVable(Ty) &&
18851887
"This is not a conversion to a SCEVable type!");
1888+
assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
18861889
Ty = getEffectiveSCEVType(Ty);
18871890

18881891
// Fold if the operand is constant.
@@ -2410,6 +2413,9 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
24102413
for (unsigned i = 1, e = Ops.size(); i != e; ++i)
24112414
assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
24122415
"SCEVAddExpr operand types don't match!");
2416+
unsigned NumPtrs = count_if(
2417+
Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2418+
assert(NumPtrs <= 1 && "add has at most one pointer operand");
24132419
#endif
24142420

24152421
// Sort by complexity, this groups all similar expression types together.
@@ -2645,12 +2651,16 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
26452651
Ops.clear();
26462652
if (AccumulatedConstant != 0)
26472653
Ops.push_back(getConstant(AccumulatedConstant));
2648-
for (auto &MulOp : MulOpLists)
2649-
if (MulOp.first != 0)
2654+
for (auto &MulOp : MulOpLists) {
2655+
if (MulOp.first == 1) {
2656+
Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2657+
} else if (MulOp.first != 0) {
26502658
Ops.push_back(getMulExpr(
26512659
getConstant(MulOp.first),
26522660
getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
26532661
SCEV::FlagAnyWrap, Depth + 1));
2662+
}
2663+
}
26542664
if (Ops.empty())
26552665
return getZero(Ty);
26562666
if (Ops.size() == 1)
@@ -2969,9 +2979,10 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
29692979
assert(!Ops.empty() && "Cannot get empty mul!");
29702980
if (Ops.size() == 1) return Ops[0];
29712981
#ifndef NDEBUG
2972-
Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2982+
Type *ETy = Ops[0]->getType();
2983+
assert(!ETy->isPointerTy());
29732984
for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2974-
assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2985+
assert(Ops[i]->getType() == ETy &&
29752986
"SCEVMulExpr operand types don't match!");
29762987
#endif
29772988

@@ -3256,8 +3267,9 @@ const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
32563267
/// possible.
32573268
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
32583269
const SCEV *RHS) {
3259-
assert(getEffectiveSCEVType(LHS->getType()) ==
3260-
getEffectiveSCEVType(RHS->getType()) &&
3270+
assert(!LHS->getType()->isPointerTy() &&
3271+
"SCEVUDivExpr operand can't be pointer!");
3272+
assert(LHS->getType() == RHS->getType() &&
32613273
"SCEVUDivExpr operand types don't match!");
32623274

32633275
FoldingSetNodeID ID;
@@ -3506,9 +3518,11 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
35063518
if (Operands.size() == 1) return Operands[0];
35073519
#ifndef NDEBUG
35083520
Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3509-
for (unsigned i = 1, e = Operands.size(); i != e; ++i)
3521+
for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
35103522
assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
35113523
"SCEVAddRecExpr operand types don't match!");
3524+
assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3525+
}
35123526
for (unsigned i = 0, e = Operands.size(); i != e; ++i)
35133527
assert(isLoopInvariant(Operands[i], L) &&
35143528
"SCEVAddRecExpr operand is not loop-invariant!");
@@ -3662,9 +3676,13 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
36623676
if (Ops.size() == 1) return Ops[0];
36633677
#ifndef NDEBUG
36643678
Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3665-
for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3679+
for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
36663680
assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
36673681
"Operand types don't match!");
3682+
assert(Ops[0]->getType()->isPointerTy() ==
3683+
Ops[i]->getType()->isPointerTy() &&
3684+
"min/max should be consistently pointerish");
3685+
}
36683686
#endif
36693687

36703688
bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
@@ -10579,6 +10597,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
1057910597
}
1058010598
}
1058110599

10600+
if (LHS->getType()->isPointerTy())
10601+
return false;
1058210602
if (CmpInst::isSigned(Pred)) {
1058310603
LHS = getSignExtendExpr(LHS, FoundLHS->getType());
1058410604
RHS = getSignExtendExpr(RHS, FoundLHS->getType());
@@ -10588,6 +10608,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
1058810608
}
1058910609
} else if (getTypeSizeInBits(LHS->getType()) >
1059010610
getTypeSizeInBits(FoundLHS->getType())) {
10611+
if (FoundLHS->getType()->isPointerTy())
10612+
return false;
1059110613
if (CmpInst::isSigned(FoundPred)) {
1059210614
FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
1059310615
FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());

llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,8 +682,11 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
682682
const APInt &RA = RC->getAPInt();
683683
// Handle x /s -1 as x * -1, to give ScalarEvolution a chance to do
684684
// some folding.
685-
if (RA.isAllOnesValue())
685+
if (RA.isAllOnesValue()) {
686+
if (LHS->getType()->isPointerTy())
687+
return nullptr;
686688
return SE.getMulExpr(LHS, RC);
689+
}
687690
// Handle x /s 1 as x.
688691
if (RA == 1)
689692
return LHS;
@@ -4063,7 +4066,8 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
40634066
// Determine the integer type for the base formula.
40644067
Type *DstTy = Base.getType();
40654068
if (!DstTy) return;
4066-
DstTy = SE.getEffectiveSCEVType(DstTy);
4069+
if (DstTy->isPointerTy())
4070+
return;
40674071

40684072
for (Type *SrcTy : Types) {
40694073
if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) {
@@ -5301,7 +5305,7 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF,
53015305
if (F.BaseGV) {
53025306
// Flush the operand list to suppress SCEVExpander hoisting.
53035307
if (!Ops.empty()) {
5304-
Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty);
5308+
Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), IntTy);
53055309
Ops.clear();
53065310
Ops.push_back(SE.getUnknown(FullV));
53075311
}

llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,10 @@ static bool canBeCheaplyTransformed(ScalarEvolution &SE,
11471147
const SCEVAddRecExpr *Phi,
11481148
const SCEVAddRecExpr *Requested,
11491149
bool &InvertStep) {
1150+
// We can't transform to match a pointer PHI.
1151+
if (Phi->getType()->isPointerTy())
1152+
return false;
1153+
11501154
Type *PhiTy = SE.getEffectiveSCEVType(Phi->getType());
11511155
Type *RequestedTy = SE.getEffectiveSCEVType(Requested->getType());
11521156

@@ -1165,8 +1169,7 @@ static bool canBeCheaplyTransformed(ScalarEvolution &SE,
11651169
}
11661170

11671171
// Check whether inverting will help: {R,+,-1} == R - {0,+,1}.
1168-
if (SE.getAddExpr(Requested->getStart(),
1169-
SE.getNegativeSCEV(Requested)) == Phi) {
1172+
if (SE.getMinusSCEV(Requested->getStart(), Requested) == Phi) {
11701173
InvertStep = true;
11711174
return true;
11721175
}
@@ -1577,8 +1580,8 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
15771580
// Rewrite an AddRec in terms of the canonical induction variable, if
15781581
// its type is more narrow.
15791582
if (CanonicalIV &&
1580-
SE.getTypeSizeInBits(CanonicalIV->getType()) >
1581-
SE.getTypeSizeInBits(Ty)) {
1583+
SE.getTypeSizeInBits(CanonicalIV->getType()) > SE.getTypeSizeInBits(Ty) &&
1584+
!S->getType()->isPointerTy()) {
15821585
SmallVector<const SCEV *, 4> NewOps(S->getNumOperands());
15831586
for (unsigned i = 0, e = S->getNumOperands(); i != e; ++i)
15841587
NewOps[i] = SE.getAnyExtendExpr(S->op_begin()[i], CanonicalIV->getType());

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
9696
const SCEV *S1 = SE.getSCEV(V1);
9797
const SCEV *S2 = SE.getSCEV(V2);
9898

99-
const SCEV *P0 = SE.getAddExpr(S0, S0);
100-
const SCEV *P1 = SE.getAddExpr(S1, S1);
101-
const SCEV *P2 = SE.getAddExpr(S2, S2);
99+
const SCEV *P0 = SE.getAddExpr(S0, SE.getConstant(S0->getType(), 2));
100+
const SCEV *P1 = SE.getAddExpr(S1, SE.getConstant(S0->getType(), 2));
101+
const SCEV *P2 = SE.getAddExpr(S2, SE.getConstant(S0->getType(), 2));
102102

103-
const SCEVMulExpr *M0 = cast<SCEVMulExpr>(P0);
104-
const SCEVMulExpr *M1 = cast<SCEVMulExpr>(P1);
105-
const SCEVMulExpr *M2 = cast<SCEVMulExpr>(P2);
103+
auto *M0 = cast<SCEVAddExpr>(P0);
104+
auto *M1 = cast<SCEVAddExpr>(P1);
105+
auto *M2 = cast<SCEVAddExpr>(P2);
106106

107107
EXPECT_EQ(cast<SCEVConstant>(M0->getOperand(0))->getValue()->getZExtValue(),
108108
2u);
@@ -707,6 +707,7 @@ TEST_F(ScalarEvolutionsTest, SCEVZeroExtendExpr) {
707707
ReturnInst::Create(Context, nullptr, EndBB);
708708
ScalarEvolution SE = buildSE(*F);
709709
const SCEV *S = SE.getSCEV(Accum);
710+
S = SE.getLosslessPtrToIntExpr(S);
710711
Type *I128Ty = Type::getInt128Ty(Context);
711712
SE.getZeroExtendExpr(S, I128Ty);
712713
}

0 commit comments

Comments
 (0)