@@ -47,7 +47,6 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
-#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/ValueHandle.h"
@@ -66,7 +65,6 @@
 #include <vector>
 
 using namespace llvm;
-using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "loop-accesses"
 
@@ -190,19 +188,22 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
 ///
 /// There is no conflict when the intervals are disjoint:
 /// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
-void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
-                                    Type *AccessTy, bool WritePtr,
-                                    unsigned DepSetId, unsigned ASId,
+void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, Type *AccessTy,
+                                    bool WritePtr, unsigned DepSetId,
+                                    unsigned ASId,
+                                    const ValueToValueMap &Strides,
                                     PredicatedScalarEvolution &PSE) {
+  // Get the stride replaced scev.
+  const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
   ScalarEvolution *SE = PSE.getSE();
 
   const SCEV *ScStart;
   const SCEV *ScEnd;
 
-  if (SE->isLoopInvariant(PtrExpr, Lp)) {
-    ScStart = ScEnd = PtrExpr;
+  if (SE->isLoopInvariant(Sc, Lp)) {
+    ScStart = ScEnd = Sc;
   } else {
-    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrExpr);
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
     assert(AR && "Invalid addrec expression");
     const SCEV *Ex = PSE.getBackedgeTakenCount();
 
@@ -229,7 +230,7 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
   const SCEV *EltSizeSCEV = SE->getStoreSizeOfExpr(IdxTy, AccessTy);
   ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
 
-  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, PtrExpr);
+  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
 }
 
 void RuntimePointerChecking::tryToCreateDiffCheck(
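For reference, the [ScStart, ScEnd) intervals built by insert() feed the NoConflict predicate quoted in the doc comment above. A minimal stand-alone sketch of that check, using hypothetical types rather than LLVM's SCEV machinery:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the [ScStart, ScEnd) bounds that
// RuntimePointerChecking::insert computes, evaluated at runtime.
struct PtrInterval {
  uint64_t Start; // address of the first byte accessed
  uint64_t End;   // one past the last byte accessed
};

// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
bool noConflict(PtrInterval P1, PtrInterval P2) {
  return P2.Start >= P1.End || P1.Start >= P2.End;
}

int main() {
  // Byte ranges [0, 400) and [400, 800) are disjoint: no runtime conflict.
  assert(noConflict({0, 400}, {400, 800}));
  // Overlapping ranges conflict, so execution must fall back to the
  // unvectorized loop version.
  assert(!noConflict({0, 400}, {200, 600}));
  return 0;
}
```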
@@ -455,11 +456,9 @@ void RuntimePointerChecking::groupChecks(
 
   unsigned TotalComparisons = 0;
 
-  DenseMap<Value *, SmallVector<unsigned>> PositionMap;
-  for (unsigned Index = 0; Index < Pointers.size(); ++Index) {
-    auto Iter = PositionMap.insert({Pointers[Index].PointerValue, {}});
-    Iter.first->second.push_back(Index);
-  }
+  DenseMap<Value *, unsigned> PositionMap;
+  for (unsigned Index = 0; Index < Pointers.size(); ++Index)
+    PositionMap[Pointers[Index].PointerValue] = Index;
 
   // We need to keep track of what pointers we've already seen so we
   // don't process them twice.
@@ -490,35 +489,34 @@ void RuntimePointerChecking::groupChecks(
     auto PointerI = PositionMap.find(MI->getPointer());
     assert(PointerI != PositionMap.end() &&
           "pointer in equivalence class not found in PositionMap");
-    for (unsigned Pointer : PointerI->second) {
-      bool Merged = false;
-      // Mark this pointer as seen.
-      Seen.insert(Pointer);
-
-      // Go through all the existing sets and see if we can find one
-      // which can include this pointer.
-      for (RuntimeCheckingPtrGroup &Group : Groups) {
-        // Don't perform more than a certain amount of comparisons.
-        // This should limit the cost of grouping the pointers to something
-        // reasonable. If we do end up hitting this threshold, the algorithm
-        // will create separate groups for all remaining pointers.
-        if (TotalComparisons > MemoryCheckMergeThreshold)
-          break;
-
-        TotalComparisons++;
-
-        if (Group.addPointer(Pointer, *this)) {
-          Merged = true;
-          break;
-        }
+    unsigned Pointer = PointerI->second;
+    bool Merged = false;
+    // Mark this pointer as seen.
+    Seen.insert(Pointer);
+
+    // Go through all the existing sets and see if we can find one
+    // which can include this pointer.
+    for (RuntimeCheckingPtrGroup &Group : Groups) {
+      // Don't perform more than a certain amount of comparisons.
+      // This should limit the cost of grouping the pointers to something
+      // reasonable. If we do end up hitting this threshold, the algorithm
+      // will create separate groups for all remaining pointers.
+      if (TotalComparisons > MemoryCheckMergeThreshold)
+        break;
+
+      TotalComparisons++;
+
+      if (Group.addPointer(Pointer, *this)) {
+        Merged = true;
+        break;
       }
-
-      if (!Merged)
-        // We couldn't add this pointer to any existing set or the threshold
-        // for the number of comparisons has been reached. Create a new group
-        // to hold the current pointer.
-        Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
     }
+
+    if (!Merged)
+      // We couldn't add this pointer to any existing set or the threshold
+      // for the number of comparisons has been reached. Create a new group
+      // to hold the current pointer.
+      Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
   }
 
   // We've computed the grouped checks for this partition.
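The loop restored above is a greedy merge with a global comparison budget. A stand-alone model of that shape, with hypothetical names and a simplified touching-ranges merge rule in place of RuntimeCheckingPtrGroup::addPointer (which additionally checks dependence-set and address-space IDs):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for MemoryCheckMergeThreshold.
constexpr unsigned MergeThreshold = 100;

struct Group {
  uint64_t Low, High; // merged byte range covered by the group

  // Simplified merge rule: accept only ranges that touch this group, so a
  // group never grows to cover unrelated address space.
  bool addPointer(uint64_t Start, uint64_t End) {
    if (Start > High || End < Low)
      return false;
    Low = std::min(Low, Start);
    High = std::max(High, End);
    return true;
  }
};

std::vector<Group>
groupIntervals(const std::vector<std::pair<uint64_t, uint64_t>> &Ptrs) {
  std::vector<Group> Groups;
  unsigned TotalComparisons = 0;
  for (auto [Start, End] : Ptrs) {
    bool Merged = false;
    for (Group &G : Groups) {
      // Past the budget, stop merging; remaining pointers get singleton
      // groups, exactly as in the loop above.
      if (TotalComparisons > MergeThreshold)
        break;
      ++TotalComparisons;
      if (G.addPointer(Start, End)) {
        Merged = true;
        break;
      }
    }
    if (!Merged)
      Groups.push_back({Start, End});
  }
  return Groups;
}

int main() {
  // [0,100) and [50,150) merge into one group; [400,500) stays separate.
  auto Groups = groupIntervals({{0, 100}, {50, 150}, {400, 500}});
  assert(Groups.size() == 2);
  return 0;
}
```

Fewer groups mean fewer emitted comparisons (checks are quadratic in the number of groups), which is why merging is worth a bounded amount of work here.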
@@ -717,8 +715,11 @@ class AccessAnalysis {
 /// Check whether a pointer can participate in a runtime bounds check.
 /// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
 /// by adding run-time checks (overflow checks) if necessary.
-static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr,
-                                const SCEV *PtrScev, Loop *L, bool Assume) {
+static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
+                                const ValueToValueMap &Strides, Value *Ptr,
+                                Loop *L, bool Assume) {
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
+
   // The bounds for loop-invariant pointer is trivial.
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
     return true;
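With this hunk, hasComputableBounds again folds in replaceSymbolicStrideSCEV before classifying the pointer. Roughly, that helper specializes a symbolic stride to 1 when loop versioning will guard the assumption at runtime, turning the recurrence into an affine one with a known step. A toy model of that idea, with hypothetical types that are not the SCEV API:

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy pointer evolution {Base,+,Stride*EltSize}: Stride may be a symbolic
// loop-invariant value (named below) rather than a compile-time constant.
struct PtrEvolution {
  long Base;
  std::string StrideSym; // empty when StrideConst is authoritative
  long StrideConst;
  long EltSize;
};

// Model of what replaceSymbolicStrideSCEV achieves: strides listed in the
// map are assumed equal to 1, and the loop is versioned with a runtime
// guard for that assumption, so the step becomes a known constant.
PtrEvolution replaceSymbolicStride(PtrEvolution P,
                                   const std::map<std::string, long> &Strides) {
  if (!P.StrideSym.empty() && Strides.count(P.StrideSym)) {
    P.StrideSym.clear();
    P.StrideConst = 1; // guarded at runtime: Stride == 1
  }
  return P;
}

bool hasComputableBoundsModel(const PtrEvolution &P) {
  // Bounds are computable once the step is a known constant: Start = Base,
  // End = Base + StrideConst * EltSize * BackedgeTakenCount.
  return P.StrideSym.empty();
}

int main() {
  PtrEvolution P{/*Base=*/0, /*StrideSym=*/"s", /*StrideConst=*/0,
                 /*EltSize=*/4};
  assert(!hasComputableBoundsModel(P));
  assert(hasComputableBoundsModel(replaceSymbolicStride(P, {{"s", 1}})));
  return 0;
}
```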
@@ -781,56 +782,34 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
                                           bool Assume) {
   Value *Ptr = Access.getPointer();
 
-  ScalarEvolution &SE = *PSE.getSE();
-  SmallVector<const SCEV *> TranslatedPtrs;
-  if (auto *SI = dyn_cast<SelectInst>(Ptr))
-    TranslatedPtrs = {SE.getSCEV(SI->getOperand(1)),
-                      SE.getSCEV(SI->getOperand(2))};
-  else
-    TranslatedPtrs = {replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr)};
+  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
+    return false;
 
-  for (const SCEV *PtrExpr : TranslatedPtrs) {
-    if (!hasComputableBounds(PSE, Ptr, PtrExpr, TheLoop, Assume))
+  // When we run after a failing dependency check we have to make sure
+  // we don't have wrapping pointers.
+  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
+    auto *Expr = PSE.getSCEV(Ptr);
+    if (!Assume || !isa<SCEVAddRecExpr>(Expr))
       return false;
+    PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+  }
 
-    // When we run after a failing dependency check we have to make sure
-    // we don't have wrapping pointers.
-    if (ShouldCheckWrap) {
-      // Skip wrap checking when translating pointers.
-      if (TranslatedPtrs.size() > 1)
-        return false;
+  // The id of the dependence set.
+  unsigned DepId;
 
-      if (!isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
-        auto *Expr = PSE.getSCEV(Ptr);
-        if (!Assume || !isa<SCEVAddRecExpr>(Expr))
-          return false;
-        PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
-      }
-    }
-    // If there's only one option for Ptr, look it up after bounds and wrap
-    // checking, because assumptions might have been added to PSE.
-    if (TranslatedPtrs.size() == 1)
-      TranslatedPtrs[0] = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
-  }
-
-  for (const SCEV *PtrExpr : TranslatedPtrs) {
-    // The id of the dependence set.
-    unsigned DepId;
-
-    if (isDependencyCheckNeeded()) {
-      Value *Leader = DepCands.getLeaderValue(Access).getPointer();
-      unsigned &LeaderId = DepSetId[Leader];
-      if (!LeaderId)
-        LeaderId = RunningDepId++;
-      DepId = LeaderId;
-    } else
-      // Each access has its own dependence set.
-      DepId = RunningDepId++;
+  if (isDependencyCheckNeeded()) {
+    Value *Leader = DepCands.getLeaderValue(Access).getPointer();
+    unsigned &LeaderId = DepSetId[Leader];
+    if (!LeaderId)
+      LeaderId = RunningDepId++;
+    DepId = LeaderId;
+  } else
+    // Each access has its own dependence set.
+    DepId = RunningDepId++;
 
-    bool IsWrite = Access.getInt();
-    RtCheck.insert(TheLoop, Ptr, PtrExpr, AccessTy, IsWrite, DepId, ASId, PSE);
-    LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
-  }
+  bool IsWrite = Access.getInt();
+  RtCheck.insert(TheLoop, Ptr, AccessTy, IsWrite, DepId, ASId, StridesMap, PSE);
+  LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
 
   return true;
 }
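The DepId bookkeeping restored above decides which pointer pairs need a memcheck at all: accesses in one dependence equivalence class share an ID (assigned lazily via their leader), and runtime checks are only emitted across different IDs. A stand-alone model with hypothetical names:

```cpp
#include <cassert>
#include <map>

// Model of the DepSetId/RunningDepId scheme in createCheckForAccess: each
// equivalence class of accesses, keyed by its leader, gets one ID lazily.
struct DepIdAssigner {
  std::map<const void *, unsigned> DepSetId; // leader -> id
  unsigned RunningDepId = 1;

  unsigned idFor(const void *Leader) {
    unsigned &Id = DepSetId[Leader]; // value-initialized to 0 when new
    if (!Id)
      Id = RunningDepId++;
    return Id;
  }
};

// Pointers sharing a DepId are handled by the dependence analysis instead,
// so a runtime bounds check is only needed across different IDs.
bool needsRuntimeCheck(unsigned DepIdA, unsigned DepIdB) {
  return DepIdA != DepIdB;
}

int main() {
  DepIdAssigner A;
  int LeaderX, LeaderY;
  assert(A.idFor(&LeaderX) == A.idFor(&LeaderX)); // same class, same id
  assert(needsRuntimeCheck(A.idFor(&LeaderX), A.idFor(&LeaderY)));
  return 0;
}
```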