Skip to content

Commit 8fef35a

Browse files
committed
[LAA] Use SCEVUse to add extra NUW flags to pointer bounds.
Use SCEVUse to add a NUW flag to the upper bound of an accessed pointer. We must already have proved that the pointers do not wrap, as otherwise we could not use them for runtime check computations. By adding the use-specific NUW flag, we can detect cases where SCEV can prove that the compared pointers must overlap, so the runtime checks will always be false. In that case, there is no point in vectorizing with runtime checks. Note that this depends on c2895cd27fbf200d1da056bc66d77eeb62690bf0, which could be submitted separately if desired; without the current change, however, I don't think it triggers in practice.
1 parent 29065eb commit 8fef35a

File tree

5 files changed

+96
-113
lines changed

5 files changed

+96
-113
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,10 @@ struct RuntimeCheckingPtrGroup {
372372

373373
/// The SCEV expression which represents the upper bound of all the
374374
/// pointers in this group.
375-
const SCEV *High;
375+
SCEVUse High;
376376
/// The SCEV expression which represents the lower bound of all the
377377
/// pointers in this group.
378-
const SCEV *Low;
378+
SCEVUse Low;
379379
/// Indices of all the pointers that constitute this grouping.
380380
SmallVector<unsigned, 2> Members;
381381
/// Address space of the involved pointers.
@@ -413,10 +413,10 @@ class RuntimePointerChecking {
413413
TrackingVH<Value> PointerValue;
414414
/// Holds the smallest byte address accessed by the pointer throughout all
415415
/// iterations of the loop.
416-
const SCEV *Start;
416+
SCEVUse Start;
417417
/// Holds the largest byte address accessed by the pointer throughout all
418418
/// iterations of the loop, plus 1.
419-
const SCEV *End;
419+
SCEVUse End;
420420
/// Holds the information if this pointer is used for writing to memory.
421421
bool IsWritePtr;
422422
/// Holds the id of the set of pointers that could be dependent because of a
@@ -429,7 +429,7 @@ class RuntimePointerChecking {
429429
/// True if the pointer expressions needs to be frozen after expansion.
430430
bool NeedsFreeze;
431431

432-
PointerInfo(Value *PointerValue, const SCEV *Start, const SCEV *End,
432+
PointerInfo(Value *PointerValue, SCEVUse Start, SCEVUse End,
433433
bool IsWritePtr, unsigned DependencySetId, unsigned AliasSetId,
434434
const SCEV *Expr, bool NeedsFreeze)
435435
: PointerValue(PointerValue), Start(Start), End(End),
@@ -443,8 +443,10 @@ class RuntimePointerChecking {
443443
/// Reset the state of the pointer runtime information.
444444
void reset() {
445445
Need = false;
446+
AlwaysFalse = false;
446447
Pointers.clear();
447448
Checks.clear();
449+
CheckingGroups.clear();
448450
}
449451

450452
/// Insert a pointer and calculate the start and end SCEVs.
@@ -501,6 +503,8 @@ class RuntimePointerChecking {
501503
/// This flag indicates if we need to add the runtime check.
502504
bool Need = false;
503505

506+
bool AlwaysFalse = false;
507+
504508
/// Information about the pointers that may require checking.
505509
SmallVector<PointerInfo, 2> Pointers;
506510

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
210210
bool NeedsFreeze) {
211211
ScalarEvolution *SE = PSE.getSE();
212212

213-
const SCEV *ScStart;
214-
const SCEV *ScEnd;
213+
SCEVUse ScStart;
214+
SCEVUse ScEnd;
215215

216216
if (SE->isLoopInvariant(PtrExpr, Lp)) {
217217
ScStart = ScEnd = PtrExpr;
@@ -223,6 +223,8 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
223223
ScStart = AR->getStart();
224224
ScEnd = AR->evaluateAtIteration(Ex, *SE);
225225
const SCEV *Step = AR->getStepRecurrence(*SE);
226+
if (auto *Comm = dyn_cast<SCEVCommutativeExpr>(ScEnd))
227+
ScEnd = SCEVUse(ScEnd, 2);
226228

227229
// For expressions with negative step, the upper bound is ScStart and the
228230
// lower bound is ScEnd.
@@ -244,7 +246,10 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
244246
auto &DL = Lp->getHeader()->getModule()->getDataLayout();
245247
Type *IdxTy = DL.getIndexType(Ptr->getType());
246248
const SCEV *EltSizeSCEV = SE->getStoreSizeOfExpr(IdxTy, AccessTy);
247-
ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
249+
// TODO: this computes one-past-the-end. ScEnd + EltSizeSCEV - 1 is the last
250+
// accessed byte. Not entirely sure if one-past-the-end must also not wrap? If
251+
// it does, could compute and use last accessed byte instead.
252+
ScEnd = SCEVUse(SE->getAddExpr(ScEnd, EltSizeSCEV), 2);
248253

249254
Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, PtrExpr,
250255
NeedsFreeze);
@@ -379,6 +384,11 @@ SmallVector<RuntimePointerCheck, 4> RuntimePointerChecking::generateChecks() {
379384
if (needsChecking(CGI, CGJ)) {
380385
tryToCreateDiffCheck(CGI, CGJ);
381386
Checks.push_back(std::make_pair(&CGI, &CGJ));
387+
if (SE->isKnownPredicate(CmpInst::ICMP_UGT, CGI.High, CGJ.Low) &&
388+
SE->isKnownPredicate(CmpInst::ICMP_ULE, CGI.Low, CGJ.High)) {
389+
AlwaysFalse = true;
390+
return {};
391+
}
382392
}
383393
}
384394
}
@@ -635,8 +645,7 @@ void RuntimePointerChecking::print(raw_ostream &OS, unsigned Depth) const {
635645
const auto &CG = CheckingGroups[I];
636646

637647
OS.indent(Depth + 2) << "Group " << &CG << ":\n";
638-
OS.indent(Depth + 4) << "(Low: " << *CG.Low << " High: " << *CG.High
639-
<< ")\n";
648+
OS.indent(Depth + 4) << "(Low: " << CG.Low << " High: " << CG.High << ")\n";
640649
for (unsigned J = 0; J < CG.Members.size(); ++J) {
641650
OS.indent(Depth + 6) << "Member: " << *Pointers[CG.Members[J]].Expr
642651
<< "\n";
@@ -1274,6 +1283,7 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
12741283
// If we can do run-time checks, but there are no checks, no runtime checks
12751284
// are needed. This can happen when all pointers point to the same underlying
12761285
// object for example.
1286+
CanDoRT &= !RtCheck.AlwaysFalse;
12771287
RtCheck.Need = CanDoRT ? RtCheck.getNumberOfChecks() != 0 : MayNeedRTCheck;
12781288

12791289
bool CanDoRTIfNeeded = !RtCheck.Need || CanDoRT;

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11886,6 +11886,8 @@ bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
1188611886
L = AE->getOperand(0);
1188711887
R = AE->getOperand(1);
1188811888
Flags = AE->getNoWrapFlags();
11889+
Flags = setFlags(AE->getNoWrapFlags(),
11890+
static_cast<SCEV::NoWrapFlags>(Expr.getInt()));
1188911891
return true;
1189011892
}
1189111893

0 commit comments

Comments
 (0)