Skip to content

Commit 5890b30

Browse files
committed
[LAA] Initial support for runtime checks with pointer selects.
Scaffolding support for generating runtime checks for multiple SCEV expressions per pointer. The initial version just adds support for looking through a single pointer select. The more sophisticated logic for analyzing forks is in D108699 Reviewed By: huntergr Differential Revision: https://reviews.llvm.org/D114487
1 parent 76775bd commit 5890b30

File tree

3 files changed

+112
-69
lines changed

3 files changed

+112
-69
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,8 @@ class RuntimePointerChecking {
406406
/// according to the assumptions that we've made during the analysis.
407407
/// The method might also version the pointer stride according to \p Strides,
408408
/// and add new predicates to \p PSE.
409-
void insert(Loop *Lp, Value *Ptr, Type *AccessTy, bool WritePtr,
410-
unsigned DepSetId, unsigned ASId, const ValueToValueMap &Strides,
409+
void insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr, Type *AccessTy,
410+
bool WritePtr, unsigned DepSetId, unsigned ASId,
411411
PredicatedScalarEvolution &PSE);
412412

413413
/// No run-time memory checking is necessary.

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 87 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "llvm/IR/Instructions.h"
4848
#include "llvm/IR/Operator.h"
4949
#include "llvm/IR/PassManager.h"
50+
#include "llvm/IR/PatternMatch.h"
5051
#include "llvm/IR/Type.h"
5152
#include "llvm/IR/Value.h"
5253
#include "llvm/IR/ValueHandle.h"
@@ -65,6 +66,7 @@
6566
#include <vector>
6667

6768
using namespace llvm;
69+
using namespace llvm::PatternMatch;
6870

6971
#define DEBUG_TYPE "loop-accesses"
7072

@@ -188,22 +190,19 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
188190
///
189191
/// There is no conflict when the intervals are disjoint:
190192
/// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
191-
void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, Type *AccessTy,
192-
bool WritePtr, unsigned DepSetId,
193-
unsigned ASId,
194-
const ValueToValueMap &Strides,
193+
void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
194+
Type *AccessTy, bool WritePtr,
195+
unsigned DepSetId, unsigned ASId,
195196
PredicatedScalarEvolution &PSE) {
196-
// Get the stride replaced scev.
197-
const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
198197
ScalarEvolution *SE = PSE.getSE();
199198

200199
const SCEV *ScStart;
201200
const SCEV *ScEnd;
202201

203-
if (SE->isLoopInvariant(Sc, Lp)) {
204-
ScStart = ScEnd = Sc;
202+
if (SE->isLoopInvariant(PtrExpr, Lp)) {
203+
ScStart = ScEnd = PtrExpr;
205204
} else {
206-
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
205+
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrExpr);
207206
assert(AR && "Invalid addrec expression");
208207
const SCEV *Ex = PSE.getBackedgeTakenCount();
209208

@@ -230,7 +229,7 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, Type *AccessTy,
230229
const SCEV *EltSizeSCEV = SE->getStoreSizeOfExpr(IdxTy, AccessTy);
231230
ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
232231

233-
Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
232+
Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, PtrExpr);
234233
}
235234

236235
SmallVector<RuntimePointerCheck, 4>
@@ -370,9 +369,11 @@ void RuntimePointerChecking::groupChecks(
370369

371370
unsigned TotalComparisons = 0;
372371

373-
DenseMap<Value *, unsigned> PositionMap;
374-
for (unsigned Index = 0; Index < Pointers.size(); ++Index)
375-
PositionMap[Pointers[Index].PointerValue] = Index;
372+
DenseMap<Value *, SmallVector<unsigned>> PositionMap;
373+
for (unsigned Index = 0; Index < Pointers.size(); ++Index) {
374+
auto Iter = PositionMap.insert({Pointers[Index].PointerValue, {}});
375+
Iter.first->second.push_back(Index);
376+
}
376377

377378
// We need to keep track of what pointers we've already seen so we
378379
// don't process them twice.
@@ -403,34 +404,35 @@ void RuntimePointerChecking::groupChecks(
403404
auto PointerI = PositionMap.find(MI->getPointer());
404405
assert(PointerI != PositionMap.end() &&
405406
"pointer in equivalence class not found in PositionMap");
406-
unsigned Pointer = PointerI->second;
407-
bool Merged = false;
408-
// Mark this pointer as seen.
409-
Seen.insert(Pointer);
410-
411-
// Go through all the existing sets and see if we can find one
412-
// which can include this pointer.
413-
for (RuntimeCheckingPtrGroup &Group : Groups) {
414-
// Don't perform more than a certain amount of comparisons.
415-
// This should limit the cost of grouping the pointers to something
416-
// reasonable. If we do end up hitting this threshold, the algorithm
417-
// will create separate groups for all remaining pointers.
418-
if (TotalComparisons > MemoryCheckMergeThreshold)
419-
break;
420-
421-
TotalComparisons++;
422-
423-
if (Group.addPointer(Pointer, *this)) {
424-
Merged = true;
425-
break;
407+
for (unsigned Pointer : PointerI->second) {
408+
bool Merged = false;
409+
// Mark this pointer as seen.
410+
Seen.insert(Pointer);
411+
412+
// Go through all the existing sets and see if we can find one
413+
// which can include this pointer.
414+
for (RuntimeCheckingPtrGroup &Group : Groups) {
415+
// Don't perform more than a certain amount of comparisons.
416+
// This should limit the cost of grouping the pointers to something
417+
// reasonable. If we do end up hitting this threshold, the algorithm
418+
// will create separate groups for all remaining pointers.
419+
if (TotalComparisons > MemoryCheckMergeThreshold)
420+
break;
421+
422+
TotalComparisons++;
423+
424+
if (Group.addPointer(Pointer, *this)) {
425+
Merged = true;
426+
break;
427+
}
426428
}
427-
}
428429

429-
if (!Merged)
430-
// We couldn't add this pointer to any existing set or the threshold
431-
// for the number of comparisons has been reached. Create a new group
432-
// to hold the current pointer.
433-
Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
430+
if (!Merged)
431+
// We couldn't add this pointer to any existing set or the threshold
432+
// for the number of comparisons has been reached. Create a new group
433+
// to hold the current pointer.
434+
Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
435+
}
434436
}
435437

436438
// We've computed the grouped checks for this partition.
@@ -629,11 +631,8 @@ class AccessAnalysis {
629631
/// Check whether a pointer can participate in a runtime bounds check.
630632
/// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
631633
/// by adding run-time checks (overflow checks) if necessary.
632-
static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
633-
const ValueToValueMap &Strides, Value *Ptr,
634-
Loop *L, bool Assume) {
635-
const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
636-
634+
static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr,
635+
const SCEV *PtrScev, Loop *L, bool Assume) {
637636
// The bounds for loop-invariant pointer is trivial.
638637
if (PSE.getSE()->isLoopInvariant(PtrScev, L))
639638
return true;
@@ -696,34 +695,56 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
696695
bool Assume) {
697696
Value *Ptr = Access.getPointer();
698697

699-
if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
700-
return false;
698+
ScalarEvolution &SE = *PSE.getSE();
699+
SmallVector<const SCEV *> TranslatedPtrs;
700+
if (auto *SI = dyn_cast<SelectInst>(Ptr))
701+
TranslatedPtrs = {SE.getSCEV(SI->getOperand(1)),
702+
SE.getSCEV(SI->getOperand(2))};
703+
else
704+
TranslatedPtrs = {replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr)};
701705

702-
// When we run after a failing dependency check we have to make sure
703-
// we don't have wrapping pointers.
704-
if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
705-
auto *Expr = PSE.getSCEV(Ptr);
706-
if (!Assume || !isa<SCEVAddRecExpr>(Expr))
706+
for (const SCEV *PtrExpr : TranslatedPtrs) {
707+
if (!hasComputableBounds(PSE, Ptr, PtrExpr, TheLoop, Assume))
707708
return false;
708-
PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
709+
710+
// When we run after a failing dependency check we have to make sure
711+
// we don't have wrapping pointers.
712+
if (ShouldCheckWrap) {
713+
// Skip wrap checking when translating pointers.
714+
if (TranslatedPtrs.size() > 1)
715+
return false;
716+
717+
if (!isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
718+
auto *Expr = PSE.getSCEV(Ptr);
719+
if (!Assume || !isa<SCEVAddRecExpr>(Expr))
720+
return false;
721+
PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
722+
}
723+
}
724+
// If there's only one option for Ptr, look it up after bounds and wrap
725+
// checking, because assumptions might have been added to PSE.
726+
if (TranslatedPtrs.size() == 1)
727+
TranslatedPtrs[0] = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
709728
}
710729

711-
// The id of the dependence set.
712-
unsigned DepId;
730+
for (const SCEV *PtrExpr : TranslatedPtrs) {
731+
// The id of the dependence set.
732+
unsigned DepId;
713733

714-
if (isDependencyCheckNeeded()) {
715-
Value *Leader = DepCands.getLeaderValue(Access).getPointer();
716-
unsigned &LeaderId = DepSetId[Leader];
717-
if (!LeaderId)
718-
LeaderId = RunningDepId++;
719-
DepId = LeaderId;
720-
} else
721-
// Each access has its own dependence set.
722-
DepId = RunningDepId++;
734+
if (isDependencyCheckNeeded()) {
735+
Value *Leader = DepCands.getLeaderValue(Access).getPointer();
736+
unsigned &LeaderId = DepSetId[Leader];
737+
if (!LeaderId)
738+
LeaderId = RunningDepId++;
739+
DepId = LeaderId;
740+
} else
741+
// Each access has its own dependence set.
742+
DepId = RunningDepId++;
723743

724-
bool IsWrite = Access.getInt();
725-
RtCheck.insert(TheLoop, Ptr, AccessTy, IsWrite, DepId, ASId, StridesMap, PSE);
726-
LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
744+
bool IsWrite = Access.getInt();
745+
RtCheck.insert(TheLoop, Ptr, PtrExpr, AccessTy, IsWrite, DepId, ASId, PSE);
746+
LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
747+
}
727748

728749
return true;
729750
}

llvm/test/Analysis/LoopAccessAnalysis/forked-pointers.ll

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,32 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
44

55
; CHECK-LABEL: function 'forked_ptrs_simple':
66
; CHECK-NEXT: loop:
7-
; CHECK-NEXT: Report: cannot identify array bounds
7+
; CHECK-NEXT: Memory dependences are safe with run-time checks
88
; CHECK-NEXT: Dependences:
99
; CHECK-NEXT: Run-time memory checks:
10+
; CHECK-NEXT: Check 0:
11+
; CHECK-NEXT: Comparing group ([[G1:.+]]):
12+
; CHECK-NEXT: %gep.Dest = getelementptr inbounds float, float* %Dest, i64 %iv
13+
; CHECK-NEXT: %gep.Dest = getelementptr inbounds float, float* %Dest, i64 %iv
14+
; CHECK-NEXT: Against group ([[G2:.+]]):
15+
; CHECK-NEXT: %select = select i1 %cmp, float* %gep.1, float* %gep.2
16+
; CHECK-NEXT: Check 1:
17+
; CHECK-NEXT: Comparing group ([[G1]]):
18+
; CHECK-NEXT: %gep.Dest = getelementptr inbounds float, float* %Dest, i64 %iv
19+
; CHECK-NEXT: %gep.Dest = getelementptr inbounds float, float* %Dest, i64 %iv
20+
; CHECK-NEXT: Against group ([[G3:.+]]):
21+
; CHECK-NEXT: %select = select i1 %cmp, float* %gep.1, float* %gep.2
1022
; CHECK-NEXT: Grouped accesses:
23+
; CHECK-NEXT: Group [[G1]]
24+
; CHECK-NEXT: (Low: %Dest High: (400 + %Dest))
25+
; CHECK-NEXT: Member: {%Dest,+,4}<nuw><%loop>
26+
; CHECK-NEXT: Member: {%Dest,+,4}<nuw><%loop>
27+
; CHECK-NEXT: Group [[G2]]:
28+
; CHECK-NEXT: (Low: %Base1 High: (400 + %Base1))
29+
; CHECK-NEXT: Member: {%Base1,+,4}<nw><%loop>
30+
; CHECK-NEXT: Group [[G3]]:
31+
; CHECK-NEXT: (Low: %Base2 High: (400 + %Base2))
32+
; CHECK-NEXT: Member: {%Base2,+,4}<nw><%loop>
1133
; CHECK-EMPTY:
1234
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
1335
; CHECK-NEXT: SCEV assumptions:

0 commit comments

Comments
 (0)