Skip to content

Commit 7bfcf93

Browse files
authored
[SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. (#118184)
A SCEVWrapPredicate A implies B, if * they have the same flag, * both steps are positive and * B's start and step are ULE/SLE (for NSUW/NSSW) than A's. See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides). Note that this is limited to steps of the same size, due to NSUW having slightly different semantics than regular NUW. We should be able to remove this restriction for NSSW (which matches NSW) in the future. PR: #118184
1 parent c539014 commit 7bfcf93

File tree

4 files changed

+72
-39
lines changed

4 files changed

+72
-39
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode {
241241
virtual bool isAlwaysTrue() const = 0;
242242

243243
/// Returns true if this predicate implies \p N.
244-
virtual bool implies(const SCEVPredicate *N) const = 0;
244+
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0;
245245

246246
/// Prints a textual representation of this predicate with an indentation of
247247
/// \p Depth.
@@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate {
286286
const SCEV *LHS, const SCEV *RHS);
287287

288288
/// Implementation of the SCEVPredicate interface
289-
bool implies(const SCEVPredicate *N) const override;
289+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
290290
void print(raw_ostream &OS, unsigned Depth = 0) const override;
291291
bool isAlwaysTrue() const override;
292292

@@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate {
393393

394394
/// Implementation of the SCEVPredicate interface
395395
const SCEVAddRecExpr *getExpr() const;
396-
bool implies(const SCEVPredicate *N) const override;
396+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
397397
void print(raw_ostream &OS, unsigned Depth = 0) const override;
398398
bool isAlwaysTrue() const override;
399399

@@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate {
418418
SmallVector<const SCEVPredicate *, 16> Preds;
419419

420420
/// Adds a predicate to this union.
421-
void add(const SCEVPredicate *N);
421+
void add(const SCEVPredicate *N, ScalarEvolution &SE);
422422

423423
public:
424-
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
424+
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
425+
ScalarEvolution &SE);
425426

426427
ArrayRef<const SCEVPredicate *> getPredicates() const { return Preds; }
427428

428429
/// Implementation of the SCEVPredicate interface
429430
bool isAlwaysTrue() const override;
430-
bool implies(const SCEVPredicate *N) const override;
431+
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
431432
void print(raw_ostream &OS, unsigned Depth) const override;
432433

433434
/// We estimate the complexity of a union predicate as the size number of

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
57065706
return true;
57075707

57085708
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709-
if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5710-
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5709+
if (Expr1 != Expr2 &&
5710+
!Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711+
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
57115712
return false;
57125713
return true;
57135714
};
@@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1485714858
bool addOverflowAssumption(const SCEVPredicate *P) {
1485814859
if (!NewPreds) {
1485914860
// Check if we've already made this assumption.
14860-
return Pred && Pred->implies(P);
14861+
return Pred && Pred->implies(P, SE);
1486114862
}
1486214863
NewPreds->push_back(P);
1486314864
return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
1493814939
assert(LHS != RHS && "LHS and RHS are the same SCEV");
1493914940
}
1494014941

14941-
bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14942+
bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14943+
ScalarEvolution &SE) const {
1494214944
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
1494314945

1494414946
if (!Op)
@@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
1496814970

1496914971
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
1497014972

14971-
bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14973+
bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14974+
ScalarEvolution &SE) const {
1497214975
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14976+
if (!Op || setFlags(Flags, Op->Flags) != Flags)
14977+
return false;
14978+
14979+
if (Op->AR == AR)
14980+
return true;
14981+
14982+
if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14983+
Flags != SCEVWrapPredicate::IncrementNUSW)
14984+
return false;
1497314985

14974-
return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14986+
const SCEV *Step = AR->getStepRecurrence(SE);
14987+
const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14988+
if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
14989+
return false;
14990+
14991+
// If both steps are positive, this implies N, if N's start and step are
14992+
// ULE/SLE (for NSUW/NSSW) than this'.
14993+
Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
14994+
Step = SE.getNoopOrZeroExtend(Step, WiderTy);
14995+
OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
14996+
14997+
bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14998+
const SCEV *OpStart = Op->AR->getStart();
14999+
const SCEV *Start = AR->getStart();
15000+
OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15001+
: SE.getNoopOrSignExtend(OpStart, WiderTy);
15002+
Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15003+
: SE.getNoopOrSignExtend(Start, WiderTy);
15004+
CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
15005+
return SE.isKnownPredicate(Pred, OpStep, Step) &&
15006+
SE.isKnownPredicate(Pred, OpStart, Start);
1497515007
}
1497615008

1497715009
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,48 +15047,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
1501515047
}
1501615048

1501715049
/// Union predicates don't get cached so create a dummy set ID for it.
15018-
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
15019-
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15050+
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15051+
ScalarEvolution &SE)
15052+
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
1502015053
for (const auto *P : Preds)
15021-
add(P);
15054+
add(P, SE);
1502215055
}
1502315056

1502415057
bool SCEVUnionPredicate::isAlwaysTrue() const {
1502515058
return all_of(Preds,
1502615059
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
1502715060
}
1502815061

15029-
bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15062+
bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15063+
ScalarEvolution &SE) const {
1503015064
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15031-
return all_of(Set->Preds,
15032-
[this](const SCEVPredicate *I) { return this->implies(I); });
15065+
return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15066+
return this->implies(I, SE);
15067+
});
1503315068

1503415069
return any_of(Preds,
15035-
[N](const SCEVPredicate *I) { return I->implies(N); });
15070+
[N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
1503615071
}
1503715072

1503815073
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
1503915074
for (const auto *Pred : Preds)
1504015075
Pred->print(OS, Depth);
1504115076
}
1504215077

15043-
void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15078+
void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
1504415079
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
1504515080
for (const auto *Pred : Set->Preds)
15046-
add(Pred);
15081+
add(Pred, SE);
1504715082
return;
1504815083
}
1504915084

1505015085
// Only add predicate if it is not already implied by this union predicate.
15051-
if (!implies(N))
15086+
if (!implies(N, SE))
1505215087
Preds.push_back(N);
1505315088
}
1505415089

1505515090
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
1505615091
Loop &L)
1505715092
: SE(SE), L(L) {
1505815093
SmallVector<const SCEVPredicate*, 4> Empty;
15059-
Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15094+
Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
1506015095
}
1506115096

1506215097
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
1512015155
}
1512115156

1512215157
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15123-
if (Preds->implies(&Pred))
15158+
if (Preds->implies(&Pred, SE))
1512415159
return;
1512515160

1512615161
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
1512715162
NewPreds.push_back(&Pred);
15128-
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15163+
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
1512915164
updateGeneration();
1513015165
}
1513115166

@@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
1519215227

1519315228
PredicatedScalarEvolution::PredicatedScalarEvolution(
1519415229
const PredicatedScalarEvolution &Init)
15195-
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15196-
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15197-
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15230+
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15231+
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15232+
SE)),
15233+
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
1519815234
for (auto I : Init.FlagsMap)
1519915235
FlagsMap.insert(I);
1520015236
}

llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
2929
; CHECK-NEXT: Run-time memory checks:
3030
; CHECK-NEXT: Check 0:
3131
; CHECK-NEXT: Comparing group
32-
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
33-
; CHECK-NEXT: Against group
3432
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11
33+
; CHECK-NEXT: Against group
34+
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
3535
; CHECK-NEXT: Grouped accesses:
3636
; CHECK-NEXT: Group
37-
; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
38-
; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
39-
; CHECK-NEXT: Group
4037
; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b))
4138
; CHECK-NEXT: Member: {%b,+,4}<%for.body>
39+
; CHECK-NEXT: Group
40+
; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
41+
; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
4242
; CHECK: Non vectorizable stores to invariant address were not found in loop.
4343
; CHECK-NEXT: SCEV assumptions:
4444
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
45-
; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
4645
; CHECK: Expressions re-written:
4746
; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom:
4847
; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64))<nuw><nsw> + %a)<nuw>
@@ -85,7 +84,6 @@ exit:
8584
; CHECK: Memory dependences are safe
8685
; CHECK: SCEV assumptions:
8786
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
88-
; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
8987
define void @test2(i64 %x, ptr %a) {
9088
entry:
9189
br label %for.body

llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
55

6-
; FIXME: {0,+,3} implies {0,+,2}.
6+
; {0,+,3} [nssw] implies {0,+,2} [nssw]
77
define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
88
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2'
99
; CHECK-NEXT: loop:
@@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
2626
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
2727
; CHECK-NEXT: SCEV assumptions:
2828
; CHECK-NEXT: {0,+,3}<%loop> Added Flags: <nssw>
29-
; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
3029
; CHECK-EMPTY:
3130
; CHECK-NEXT: Expressions re-written:
3231
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
@@ -59,7 +58,7 @@ exit:
5958
ret void
6059
}
6160

62-
; FIXME: {2,+,2} implies {0,+,2}.
61+
; {2,+,2} [nssw] implies {0,+,2} [nssw].
6362
define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) {
6463
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start'
6564
; CHECK-NEXT: loop:
@@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d
8281
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
8382
; CHECK-NEXT: SCEV assumptions:
8483
; CHECK-NEXT: {2,+,2}<%loop> Added Flags: <nssw>
85-
; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
8684
; CHECK-EMPTY:
8785
; CHECK-NEXT: Expressions re-written:
8886
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:

0 commit comments

Comments
 (0)