Skip to content

Commit 81bcaa7

Browse files
committed
[SCEV] Use Step and Start to check if SCEVWrapPredicate is implied.
A SCEVWrapPredicate A implies B, if * they have the same flag, * both steps are positive and * B's start and step are ULE/SLE (for NSUW/NSSW) than A's. See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides). Note that this is limited to steps of the same size, due to NSUW having slightly different semantics than regular NUW. We should be able to remove this restriction for NSSW (which matches NSW) in the future.
1 parent 2e30df7 commit 81bcaa7

File tree

4 files changed

+84
-39
lines changed

4 files changed

+84
-39
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode {
241241
virtual bool isAlwaysTrue() const = 0;
242242

243243
/// Returns true if this predicate implies \p N.
244-
virtual bool implies(const SCEVPredicate *N) const = 0;
244+
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0;
245245

246246
/// Prints a textual representation of this predicate with an indentation of
247247
/// \p Depth.
@@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate {
286286
const SCEV *LHS, const SCEV *RHS);
287287

288288
/// Implementation of the SCEVPredicate interface
289-
bool implies(const SCEVPredicate *N) const override;
289+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
290290
void print(raw_ostream &OS, unsigned Depth = 0) const override;
291291
bool isAlwaysTrue() const override;
292292

@@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate {
393393

394394
/// Implementation of the SCEVPredicate interface
395395
const SCEVAddRecExpr *getExpr() const;
396-
bool implies(const SCEVPredicate *N) const override;
396+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
397397
void print(raw_ostream &OS, unsigned Depth = 0) const override;
398398
bool isAlwaysTrue() const override;
399399

@@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate {
418418
SmallVector<const SCEVPredicate *, 16> Preds;
419419

420420
/// Adds a predicate to this union.
421-
void add(const SCEVPredicate *N);
421+
void add(const SCEVPredicate *N, ScalarEvolution &SE);
422422

423423
public:
424-
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
424+
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
425+
ScalarEvolution &SE);
425426

426427
ArrayRef<const SCEVPredicate *> getPredicates() const { return Preds; }
427428

428429
/// Implementation of the SCEVPredicate interface
429430
bool isAlwaysTrue() const override;
430-
bool implies(const SCEVPredicate *N) const override;
431+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
431432
void print(raw_ostream &OS, unsigned Depth) const override;
432433

433434
/// We estimate the complexity of a union predicate as the size number of

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5725,8 +5725,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
57255725
return true;
57265726

57275727
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5728-
if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5729-
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5728+
if (Expr1 != Expr2 &&
5729+
!Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5730+
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
57305731
return false;
57315732
return true;
57325733
};
@@ -14823,7 +14824,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1482314824
bool addOverflowAssumption(const SCEVPredicate *P) {
1482414825
if (!NewPreds) {
1482514826
// Check if we've already made this assumption.
14826-
return Pred && Pred->implies(P);
14827+
return Pred && Pred->implies(P, SE);
1482714828
}
1482814829
NewPreds->push_back(P);
1482914830
return true;
@@ -14904,7 +14905,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
1490414905
assert(LHS != RHS && "LHS and RHS are the same SCEV");
1490514906
}
1490614907

14907-
bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14908+
bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14909+
ScalarEvolution &SE) const {
1490814910
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
1490914911

1491014912
if (!Op)
@@ -14934,10 +14936,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
1493414936

1493514937
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
1493614938

14937-
bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14939+
bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14940+
ScalarEvolution &SE) const {
1493814941
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14942+
if (!Op)
14943+
return false;
14944+
14945+
if (setFlags(Flags, Op->Flags) != Flags)
14946+
return false;
14947+
14948+
if (Op->AR == AR)
14949+
return true;
14950+
14951+
if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14952+
Flags != SCEVWrapPredicate::IncrementNUSW)
14953+
return false;
1493914954

14940-
return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14955+
bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14956+
const SCEV *Step = AR->getStepRecurrence(SE);
14957+
const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14958+
14959+
// If both steps are positive, this implies N, if N's start and step are
14960+
// ULE/SLE (for NSUW/NSSW) than this'.
14961+
if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
14962+
const SCEV *OpStart = Op->AR->getStart();
14963+
const SCEV *Start = AR->getStart();
14964+
if (SE.getTypeSizeInBits(Step->getType()) >
14965+
SE.getTypeSizeInBits(OpStep->getType())) {
14966+
OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
14967+
} else {
14968+
Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
14969+
: SE.getNoopOrSignExtend(Step, OpStep->getType());
14970+
}
14971+
if (SE.getTypeSizeInBits(Start->getType()) >
14972+
SE.getTypeSizeInBits(OpStart->getType())) {
14973+
OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
14974+
: SE.getSignExtendExpr(OpStart, Start->getType());
14975+
} else {
14976+
Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
14977+
: SE.getNoopOrSignExtend(Start, OpStart->getType());
14978+
}
14979+
14980+
CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
14981+
return SE.isKnownPredicate(Pred, OpStep, Step) &&
14982+
SE.isKnownPredicate(Pred, OpStart, Start);
14983+
}
14984+
return false;
1494114985
}
1494214986

1494314987
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -14981,48 +15025,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
1498115025
}
1498215026

1498315027
/// Union predicates don't get cached so create a dummy set ID for it.
14984-
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14985-
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15028+
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15029+
ScalarEvolution &SE)
15030+
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
1498615031
for (const auto *P : Preds)
14987-
add(P);
15032+
add(P, SE);
1498815033
}
1498915034

1499015035
bool SCEVUnionPredicate::isAlwaysTrue() const {
1499115036
return all_of(Preds,
1499215037
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
1499315038
}
1499415039

14995-
bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15040+
bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15041+
ScalarEvolution &SE) const {
1499615042
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14997-
return all_of(Set->Preds,
14998-
[this](const SCEVPredicate *I) { return this->implies(I); });
15043+
return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15044+
return this->implies(I, SE);
15045+
});
1499915046

1500015047
return any_of(Preds,
15001-
[N](const SCEVPredicate *I) { return I->implies(N); });
15048+
[N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
1500215049
}
1500315050

1500415051
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
1500515052
for (const auto *Pred : Preds)
1500615053
Pred->print(OS, Depth);
1500715054
}
1500815055

15009-
void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15056+
void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
1501015057
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
1501115058
for (const auto *Pred : Set->Preds)
15012-
add(Pred);
15059+
add(Pred, SE);
1501315060
return;
1501415061
}
1501515062

1501615063
// Only add predicate if it is not already implied by this union predicate.
15017-
if (!implies(N))
15064+
if (!implies(N, SE))
1501815065
Preds.push_back(N);
1501915066
}
1502015067

1502115068
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
1502215069
Loop &L)
1502315070
: SE(SE), L(L) {
1502415071
SmallVector<const SCEVPredicate*, 4> Empty;
15025-
Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15072+
Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
1502615073
}
1502715074

1502815075
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15086,12 +15133,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
1508615133
}
1508715134

1508815135
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15089-
if (Preds->implies(&Pred))
15136+
if (Preds->implies(&Pred, SE))
1509015137
return;
1509115138

1509215139
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
1509315140
NewPreds.push_back(&Pred);
15094-
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15141+
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
1509515142
updateGeneration();
1509615143
}
1509715144

@@ -15158,9 +15205,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
1515815205

1515915206
PredicatedScalarEvolution::PredicatedScalarEvolution(
1516015207
const PredicatedScalarEvolution &Init)
15161-
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15162-
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15163-
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15208+
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15209+
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15210+
SE)),
15211+
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
1516415212
for (auto I : Init.FlagsMap)
1516515213
FlagsMap.insert(I);
1516615214
}

llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
2929
; CHECK-NEXT: Run-time memory checks:
3030
; CHECK-NEXT: Check 0:
3131
; CHECK-NEXT: Comparing group
32-
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
33-
; CHECK-NEXT: Against group
3432
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11
33+
; CHECK-NEXT: Against group
34+
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
3535
; CHECK-NEXT: Grouped accesses:
3636
; CHECK-NEXT: Group
37-
; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
38-
; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
39-
; CHECK-NEXT: Group
4037
; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b))
4138
; CHECK-NEXT: Member: {%b,+,4}<%for.body>
39+
; CHECK-NEXT: Group
40+
; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
41+
; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
4242
; CHECK: Non vectorizable stores to invariant address were not found in loop.
4343
; CHECK-NEXT: SCEV assumptions:
4444
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
45-
; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
4645
; CHECK: Expressions re-written:
4746
; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom:
4847
; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64))<nuw><nsw> + %a)<nuw>
@@ -85,7 +84,6 @@ exit:
8584
; CHECK: Memory dependences are safe
8685
; CHECK: SCEV assumptions:
8786
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
88-
; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
8987
define void @test2(i64 %x, ptr %a) {
9088
entry:
9189
br label %for.body

llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
55

6-
; FIXME: {0,+,3} implies {0,+,2}.
6+
; {0,+,3} [nssw] implies {0,+,2} [nssw]
77
define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
88
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2'
99
; CHECK-NEXT: loop:
@@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
2626
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
2727
; CHECK-NEXT: SCEV assumptions:
2828
; CHECK-NEXT: {0,+,3}<%loop> Added Flags: <nssw>
29-
; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
3029
; CHECK-EMPTY:
3130
; CHECK-NEXT: Expressions re-written:
3231
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
@@ -59,7 +58,7 @@ exit:
5958
ret void
6059
}
6160

62-
; FIXME: {2,+,2} implies {0,+,2}.
61+
; {2,+,2} [nssw] implies {0,+,2} [nssw].
6362
define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) {
6463
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start'
6564
; CHECK-NEXT: loop:
@@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d
8281
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
8382
; CHECK-NEXT: SCEV assumptions:
8483
; CHECK-NEXT: {2,+,2}<%loop> Added Flags: <nssw>
85-
; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
8684
; CHECK-EMPTY:
8785
; CHECK-NEXT: Expressions re-written:
8886
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:

0 commit comments

Comments
 (0)