Skip to content

Commit 632fe58

Browse files
committed
[SCEV] Use Step and Start to check if SCEVWrapPredicate is implied.
A SCEVWrapPredicate A implies B, if * they have the same flag, * both steps are positive and * B's start and step are ULE/SLE (for NSUW/NSSW) than A's. See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides). Note that this is limited to steps of the same size, due to NSUW having slightly different semantics than regular NUW. We should be able to remove this restriction for NSSW (which matches NSW) in the future.
1 parent 3ad2399 commit 632fe58

File tree

4 files changed

+84
-39
lines changed

4 files changed

+84
-39
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode {
241241
virtual bool isAlwaysTrue() const = 0;
242242

243243
/// Returns true if this predicate implies \p N.
244-
virtual bool implies(const SCEVPredicate *N) const = 0;
244+
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0;
245245

246246
/// Prints a textual representation of this predicate with an indentation of
247247
/// \p Depth.
@@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate {
286286
const SCEV *LHS, const SCEV *RHS);
287287

288288
/// Implementation of the SCEVPredicate interface
289-
bool implies(const SCEVPredicate *N) const override;
289+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
290290
void print(raw_ostream &OS, unsigned Depth = 0) const override;
291291
bool isAlwaysTrue() const override;
292292

@@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate {
393393

394394
/// Implementation of the SCEVPredicate interface
395395
const SCEVAddRecExpr *getExpr() const;
396-
bool implies(const SCEVPredicate *N) const override;
396+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
397397
void print(raw_ostream &OS, unsigned Depth = 0) const override;
398398
bool isAlwaysTrue() const override;
399399

@@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate {
418418
SmallVector<const SCEVPredicate *, 16> Preds;
419419

420420
/// Adds a predicate to this union.
421-
void add(const SCEVPredicate *N);
421+
void add(const SCEVPredicate *N, ScalarEvolution &SE);
422422

423423
public:
424-
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
424+
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
425+
ScalarEvolution &SE);
425426

426427
ArrayRef<const SCEVPredicate *> getPredicates() const { return Preds; }
427428

428429
/// Implementation of the SCEVPredicate interface
429430
bool isAlwaysTrue() const override;
430-
bool implies(const SCEVPredicate *N) const override;
431+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
431432
void print(raw_ostream &OS, unsigned Depth) const override;
432433

433434
/// We estimate the complexity of a union predicate as the size number of

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
57065706
return true;
57075707

57085708
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709-
if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5710-
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5709+
if (Expr1 != Expr2 &&
5710+
!Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711+
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
57115712
return false;
57125713
return true;
57135714
};
@@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1485714858
bool addOverflowAssumption(const SCEVPredicate *P) {
1485814859
if (!NewPreds) {
1485914860
// Check if we've already made this assumption.
14860-
return Pred && Pred->implies(P);
14861+
return Pred && Pred->implies(P, SE);
1486114862
}
1486214863
NewPreds->push_back(P);
1486314864
return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
1493814939
assert(LHS != RHS && "LHS and RHS are the same SCEV");
1493914940
}
1494014941

14941-
bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14942+
bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14943+
ScalarEvolution &SE) const {
1494214944
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
1494314945

1494414946
if (!Op)
@@ -14968,10 +14970,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
1496814970

1496914971
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
1497014972

14971-
bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14973+
bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14974+
ScalarEvolution &SE) const {
1497214975
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14976+
if (!Op)
14977+
return false;
14978+
14979+
if (setFlags(Flags, Op->Flags) != Flags)
14980+
return false;
14981+
14982+
if (Op->AR == AR)
14983+
return true;
14984+
14985+
if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14986+
Flags != SCEVWrapPredicate::IncrementNUSW)
14987+
return false;
1497314988

14974-
return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14989+
bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14990+
const SCEV *Step = AR->getStepRecurrence(SE);
14991+
const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14992+
14993+
// If both steps are positive, this implies N, if N's start and step are
14994+
// ULE/SLE (for NSUW/NSSW) than this'.
14995+
if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
14996+
const SCEV *OpStart = Op->AR->getStart();
14997+
const SCEV *Start = AR->getStart();
14998+
if (SE.getTypeSizeInBits(Step->getType()) >
14999+
SE.getTypeSizeInBits(OpStep->getType())) {
15000+
OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
15001+
} else {
15002+
Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
15003+
: SE.getNoopOrSignExtend(Step, OpStep->getType());
15004+
}
15005+
if (SE.getTypeSizeInBits(Start->getType()) >
15006+
SE.getTypeSizeInBits(OpStart->getType())) {
15007+
OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
15008+
: SE.getSignExtendExpr(OpStart, Start->getType());
15009+
} else {
15010+
Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
15011+
: SE.getNoopOrSignExtend(Start, OpStart->getType());
15012+
}
15013+
15014+
CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
15015+
return SE.isKnownPredicate(Pred, OpStep, Step) &&
15016+
SE.isKnownPredicate(Pred, OpStart, Start);
15017+
}
15018+
return false;
1497515019
}
1497615020

1497715021
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,48 +15059,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
1501515059
}
1501615060

1501715061
/// Union predicates don't get cached so create a dummy set ID for it.
15018-
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
15019-
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15062+
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15063+
ScalarEvolution &SE)
15064+
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
1502015065
for (const auto *P : Preds)
15021-
add(P);
15066+
add(P, SE);
1502215067
}
1502315068

1502415069
bool SCEVUnionPredicate::isAlwaysTrue() const {
1502515070
return all_of(Preds,
1502615071
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
1502715072
}
1502815073

15029-
bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15074+
bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15075+
ScalarEvolution &SE) const {
1503015076
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15031-
return all_of(Set->Preds,
15032-
[this](const SCEVPredicate *I) { return this->implies(I); });
15077+
return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15078+
return this->implies(I, SE);
15079+
});
1503315080

1503415081
return any_of(Preds,
15035-
[N](const SCEVPredicate *I) { return I->implies(N); });
15082+
[N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
1503615083
}
1503715084

1503815085
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
1503915086
for (const auto *Pred : Preds)
1504015087
Pred->print(OS, Depth);
1504115088
}
1504215089

15043-
void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15090+
void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
1504415091
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
1504515092
for (const auto *Pred : Set->Preds)
15046-
add(Pred);
15093+
add(Pred, SE);
1504715094
return;
1504815095
}
1504915096

1505015097
// Only add predicate if it is not already implied by this union predicate.
15051-
if (!implies(N))
15098+
if (!implies(N, SE))
1505215099
Preds.push_back(N);
1505315100
}
1505415101

1505515102
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
1505615103
Loop &L)
1505715104
: SE(SE), L(L) {
1505815105
SmallVector<const SCEVPredicate*, 4> Empty;
15059-
Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15106+
Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
1506015107
}
1506115108

1506215109
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15167,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
1512015167
}
1512115168

1512215169
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15123-
if (Preds->implies(&Pred))
15170+
if (Preds->implies(&Pred, SE))
1512415171
return;
1512515172

1512615173
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
1512715174
NewPreds.push_back(&Pred);
15128-
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15175+
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
1512915176
updateGeneration();
1513015177
}
1513115178

@@ -15192,9 +15239,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
1519215239

1519315240
PredicatedScalarEvolution::PredicatedScalarEvolution(
1519415241
const PredicatedScalarEvolution &Init)
15195-
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15196-
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15197-
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15242+
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15243+
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15244+
SE)),
15245+
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
1519815246
for (auto I : Init.FlagsMap)
1519915247
FlagsMap.insert(I);
1520015248
}

llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
2929
; CHECK-NEXT: Run-time memory checks:
3030
; CHECK-NEXT: Check 0:
3131
; CHECK-NEXT: Comparing group
32-
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
33-
; CHECK-NEXT: Against group
3432
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11
33+
; CHECK-NEXT: Against group
34+
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
3535
; CHECK-NEXT: Grouped accesses:
3636
; CHECK-NEXT: Group
37-
; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
38-
; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
39-
; CHECK-NEXT: Group
4037
; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b))
4138
; CHECK-NEXT: Member: {%b,+,4}<%for.body>
39+
; CHECK-NEXT: Group
40+
; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
41+
; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
4242
; CHECK: Non vectorizable stores to invariant address were not found in loop.
4343
; CHECK-NEXT: SCEV assumptions:
4444
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
45-
; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
4645
; CHECK: Expressions re-written:
4746
; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom:
4847
; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64))<nuw><nsw> + %a)<nuw>
@@ -85,7 +84,6 @@ exit:
8584
; CHECK: Memory dependences are safe
8685
; CHECK: SCEV assumptions:
8786
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
88-
; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
8987
define void @test2(i64 %x, ptr %a) {
9088
entry:
9189
br label %for.body

llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
55

6-
; FIXME: {0,+,3} implies {0,+,2}.
6+
; {0,+,3} [nssw] implies {0,+,2} [nssw]
77
define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
88
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2'
99
; CHECK-NEXT: loop:
@@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
2626
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
2727
; CHECK-NEXT: SCEV assumptions:
2828
; CHECK-NEXT: {0,+,3}<%loop> Added Flags: <nssw>
29-
; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
3029
; CHECK-EMPTY:
3130
; CHECK-NEXT: Expressions re-written:
3231
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
@@ -59,7 +58,7 @@ exit:
5958
ret void
6059
}
6160

62-
; FIXME: {2,+,2} implies {0,+,2}.
61+
; {2,+,2} [nssw] implies {0,+,2} [nssw].
6362
define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) {
6463
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start'
6564
; CHECK-NEXT: loop:
@@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d
8281
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
8382
; CHECK-NEXT: SCEV assumptions:
8483
; CHECK-NEXT: {2,+,2}<%loop> Added Flags: <nssw>
85-
; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
8684
; CHECK-EMPTY:
8785
; CHECK-NEXT: Expressions re-written:
8886
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:

0 commit comments

Comments
 (0)