@@ -47,6 +47,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/ValueHandle.h"
@@ -65,6 +66,7 @@
 #include <vector>
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "loop-accesses"
 
@@ -188,22 +190,19 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
 ///
 /// There is no conflict when the intervals are disjoint:
 /// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
-void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, Type *AccessTy,
-                                    bool WritePtr, unsigned DepSetId,
-                                    unsigned ASId,
-                                    const ValueToValueMap &Strides,
+void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
+                                    Type *AccessTy, bool WritePtr,
+                                    unsigned DepSetId, unsigned ASId,
                                     PredicatedScalarEvolution &PSE) {
-  // Get the stride replaced scev.
-  const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
   ScalarEvolution *SE = PSE.getSE();
 
   const SCEV *ScStart;
   const SCEV *ScEnd;
 
-  if (SE->isLoopInvariant(Sc, Lp)) {
-    ScStart = ScEnd = Sc;
+  if (SE->isLoopInvariant(PtrExpr, Lp)) {
+    ScStart = ScEnd = PtrExpr;
   } else {
-    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrExpr);
     assert(AR && "Invalid addrec expression");
     const SCEV *Ex = PSE.getBackedgeTakenCount();
 
@@ -230,7 +229,7 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, Type *AccessTy,
   const SCEV *EltSizeSCEV = SE->getStoreSizeOfExpr(IdxTy, AccessTy);
   ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
 
-  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
+  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, PtrExpr);
 }
 
 SmallVector<RuntimePointerCheck, 4>
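For intuition, here is a minimal standalone sketch (plain C++, not the LLVM API) of the disjointness predicate quoted in the doc comment above: each pointer's accesses are summarized as a half-open byte interval [Start, End), and two intervals need no runtime conflict check exactly when one starts at or past the other's end.

```cpp
#include <cassert>

// Stand-in for the per-pointer access range [ScStart, ScEnd) computed above.
struct PtrInterval {
  unsigned Start; // first byte accessed
  unsigned End;   // one past the last byte accessed
};

// Mirrors: NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
static bool noConflict(const PtrInterval &P1, const PtrInterval &P2) {
  return P2.Start >= P1.End || P1.Start >= P2.End;
}

int main() {
  assert(noConflict({0, 16}, {16, 32}));  // adjacent ranges are disjoint
  assert(!noConflict({0, 16}, {8, 24}));  // overlapping ranges conflict
  return 0;
}
```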
@@ -370,9 +369,11 @@ void RuntimePointerChecking::groupChecks(
 
   unsigned TotalComparisons = 0;
 
-  DenseMap<Value *, unsigned> PositionMap;
-  for (unsigned Index = 0; Index < Pointers.size(); ++Index)
-    PositionMap[Pointers[Index].PointerValue] = Index;
+  DenseMap<Value *, SmallVector<unsigned>> PositionMap;
+  for (unsigned Index = 0; Index < Pointers.size(); ++Index) {
+    auto Iter = PositionMap.insert({Pointers[Index].PointerValue, {}});
+    Iter.first->second.push_back(Index);
+  }
 
   // We need to keep track of what pointers we've already seen so we
   // don't process them twice.
@@ -403,34 +404,35 @@ void RuntimePointerChecking::groupChecks(
       auto PointerI = PositionMap.find(MI->getPointer());
       assert(PointerI != PositionMap.end() &&
             "pointer in equivalence class not found in PositionMap");
-      unsigned Pointer = PointerI->second;
-      bool Merged = false;
-      // Mark this pointer as seen.
-      Seen.insert(Pointer);
-
-      // Go through all the existing sets and see if we can find one
-      // which can include this pointer.
-      for (RuntimeCheckingPtrGroup &Group : Groups) {
-        // Don't perform more than a certain amount of comparisons.
-        // This should limit the cost of grouping the pointers to something
-        // reasonable. If we do end up hitting this threshold, the algorithm
-        // will create separate groups for all remaining pointers.
-        if (TotalComparisons > MemoryCheckMergeThreshold)
-          break;
-
-        TotalComparisons++;
-
-        if (Group.addPointer(Pointer, *this)) {
-          Merged = true;
-          break;
+      for (unsigned Pointer : PointerI->second) {
+        bool Merged = false;
+        // Mark this pointer as seen.
+        Seen.insert(Pointer);
+
+        // Go through all the existing sets and see if we can find one
+        // which can include this pointer.
+        for (RuntimeCheckingPtrGroup &Group : Groups) {
+          // Don't perform more than a certain amount of comparisons.
+          // This should limit the cost of grouping the pointers to something
+          // reasonable. If we do end up hitting this threshold, the algorithm
+          // will create separate groups for all remaining pointers.
+          if (TotalComparisons > MemoryCheckMergeThreshold)
+            break;
+
+          TotalComparisons++;
+
+          if (Group.addPointer(Pointer, *this)) {
+            Merged = true;
+            break;
+          }
         }
-      }
 
-      if (!Merged)
-        // We couldn't add this pointer to any existing set or the threshold
-        // for the number of comparisons has been reached. Create a new group
-        // to hold the current pointer.
-        Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
+        if (!Merged)
+          // We couldn't add this pointer to any existing set or the threshold
+          // for the number of comparisons has been reached. Create a new group
+          // to hold the current pointer.
+          Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
+      }
     }
 
     // We've computed the grouped checks for this partition.
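A note on the PositionMap change above: with the select translation introduced later in this patch, a single IR pointer value can occupy more than one slot in Pointers (one per select arm), so the map must collect every index for a value instead of keeping a single one. A rough stand-in using standard containers (hypothetical names, not the LLVM types):

```cpp
#include <unordered_map>
#include <vector>

// Stand-in for Value *; the real code keys the map on llvm::Value *.
using ValueKey = const void *;

// Each pointer value maps to every index it occupies in the Pointers list,
// mirroring DenseMap<Value *, SmallVector<unsigned>> in the patch.
using PositionMapT = std::unordered_map<ValueKey, std::vector<unsigned>>;

void recordPointer(PositionMapT &PositionMap, ValueKey PtrValue,
                   unsigned Index) {
  // insert() leaves an existing entry untouched, matching the
  // DenseMap::insert + push_back pattern in the patch.
  auto Iter = PositionMap.insert({PtrValue, {}});
  Iter.first->second.push_back(Index);
}
```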
@@ -629,11 +631,8 @@ class AccessAnalysis {
 /// Check whether a pointer can participate in a runtime bounds check.
 /// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
 /// by adding run-time checks (overflow checks) if necessary.
-static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
-                                const ValueToValueMap &Strides, Value *Ptr,
-                                Loop *L, bool Assume) {
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
-
+static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr,
+                                const SCEV *PtrScev, Loop *L, bool Assume) {
   // The bounds for loop-invariant pointer is trivial.
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
     return true;
@@ -696,34 +695,56 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
                                           bool Assume) {
   Value *Ptr = Access.getPointer();
 
-  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
-    return false;
+  ScalarEvolution &SE = *PSE.getSE();
+  SmallVector<const SCEV *> TranslatedPtrs;
+  if (auto *SI = dyn_cast<SelectInst>(Ptr))
+    TranslatedPtrs = {SE.getSCEV(SI->getOperand(1)),
+                      SE.getSCEV(SI->getOperand(2))};
+  else
+    TranslatedPtrs = {replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr)};
 
-  // When we run after a failing dependency check we have to make sure
-  // we don't have wrapping pointers.
-  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
-    auto *Expr = PSE.getSCEV(Ptr);
-    if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+  for (const SCEV *PtrExpr : TranslatedPtrs) {
+    if (!hasComputableBounds(PSE, Ptr, PtrExpr, TheLoop, Assume))
       return false;
-    PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+
+    // When we run after a failing dependency check we have to make sure
+    // we don't have wrapping pointers.
+    if (ShouldCheckWrap) {
+      // Skip wrap checking when translating pointers.
+      if (TranslatedPtrs.size() > 1)
+        return false;
+
+      if (!isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
+        auto *Expr = PSE.getSCEV(Ptr);
+        if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+          return false;
+        PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+      }
+    }
+    // If there's only one option for Ptr, look it up after bounds and wrap
+    // checking, because assumptions might have been added to PSE.
+    if (TranslatedPtrs.size() == 1)
+      TranslatedPtrs[0] = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
   }
 
-  // The id of the dependence set.
-  unsigned DepId;
+  for (const SCEV *PtrExpr : TranslatedPtrs) {
+    // The id of the dependence set.
+    unsigned DepId;
 
-  if (isDependencyCheckNeeded()) {
-    Value *Leader = DepCands.getLeaderValue(Access).getPointer();
-    unsigned &LeaderId = DepSetId[Leader];
-    if (!LeaderId)
-      LeaderId = RunningDepId++;
-    DepId = LeaderId;
-  } else
-    // Each access has its own dependence set.
-    DepId = RunningDepId++;
+    if (isDependencyCheckNeeded()) {
+      Value *Leader = DepCands.getLeaderValue(Access).getPointer();
+      unsigned &LeaderId = DepSetId[Leader];
+      if (!LeaderId)
+        LeaderId = RunningDepId++;
+      DepId = LeaderId;
+    } else
+      // Each access has its own dependence set.
+      DepId = RunningDepId++;
 
-  bool IsWrite = Access.getInt();
-  RtCheck.insert(TheLoop, Ptr, AccessTy, IsWrite, DepId, ASId, StridesMap, PSE);
-  LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+    bool IsWrite = Access.getInt();
+    RtCheck.insert(TheLoop, Ptr, PtrExpr, AccessTy, IsWrite, DepId, ASId, PSE);
+    LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+  }
 
   return true;
 }
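Taken together, the patch lets createCheckForAccess handle loops like the following hypothetical source-level example, where the loaded address lowers to a `select` between two loop-invariant bases in IR; instead of failing to compute bounds for the select, LAA now translates it into the SCEVs of both arms and inserts a runtime bounds check for each (a sketch for illustration, not taken from the patch):

```cpp
// The pointer operand of the load is select(UseA, A, B) in IR. With this
// change, both A and B get their own runtime bounds check against Dst,
// so the loop remains a candidate for vectorization.
void copySelected(int *Dst, const int *A, const int *B, bool UseA, int N) {
  for (int I = 0; I < N; ++I)
    Dst[I] = (UseA ? A : B)[I];
}
```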