Skip to content

Commit 60fd999

Browse files
committed
[SCEVAA] Allowing to subtract two inttoptrs with different pointer bases
1 parent 8f21294 commit 60fd999

File tree

4 files changed

+158
-13
lines changed

4 files changed

+158
-13
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,11 @@ class ScalarEvolution {
721721
const SCEV *getTruncateOrSignExtend(const SCEV *V, Type *Ty,
722722
unsigned Depth = 0);
723723

724+
/// Return a SCEV corresponding to a conversion of the input value to the
725+
/// specified type. If the type must be extended, it is any extended.
726+
const SCEV *getTruncateOrAnyExtend(const SCEV *V, Type *Ty,
727+
unsigned Depth = 0);
728+
724729
/// Return a SCEV corresponding to a conversion of the input value to the
725730
/// specified type. If the type must be extended, it is zero extended. The
726731
/// conversion must not be narrowing.
@@ -754,6 +759,26 @@ class ScalarEvolution {
754759
const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
755760
bool Sequential = false);
756761

762+
/// Promote the operands to the wider of the types using any-extension, and
763+
/// then perform a addrec operation with them.
764+
const SCEV *
765+
getAddRecExprFromMismatchedTypes(const SmallVectorImpl<const SCEV *> &Ops,
766+
const Loop *L, SCEV::NoWrapFlags Flags);
767+
768+
/// Promote the operands to the wider of the types using any-extension, and
769+
/// then perform a add operation with them.
770+
const SCEV *
771+
getAddExprFromMismatchedTypes(const SmallVectorImpl<const SCEV *> &Ops,
772+
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
773+
unsigned Depth = 0);
774+
const SCEV *
775+
getAddExprFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS,
776+
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
777+
unsigned Depth = 0) {
778+
SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
779+
return getAddExprFromMismatchedTypes(Ops, Flags, Depth);
780+
}
781+
757782
/// Transitively follow the chain of pointer-type operands until reaching a
758783
/// SCEV that does not have a single pointer operand. This returns a
759784
/// SCEVUnknown pointer for well-formed pointer-type expressions, but corner

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4650,7 +4650,8 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
46504650
Ops[0] = removePointerBase(Ops[0]);
46514651
// Don't try to transfer nowrap flags for now. We could in some cases
46524652
// (for example, if pointer operand of the AddRec is a SCEVUnknown).
4653-
return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4653+
return getAddRecExprFromMismatchedTypes(Ops, AddRec->getLoop(),
4654+
SCEV::FlagAnyWrap);
46544655
}
46554656
if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
46564657
// The base of an Add is the pointer operand.
@@ -4665,25 +4666,47 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
46654666
*PtrOp = removePointerBase(*PtrOp);
46664667
// Don't try to transfer nowrap flags for now. We could in some cases
46674668
// (for example, if the pointer operand of the Add is a SCEVUnknown).
4668-
return getAddExpr(Ops);
4669+
return getAddExprFromMismatchedTypes(Ops);
46694670
}
4671+
4672+
if (auto *Unknown = dyn_cast<SCEVUnknown>(P)) {
4673+
if (auto *O = dyn_cast<Operator>(Unknown->getValue())) {
4674+
if (O->getOpcode() == Instruction::IntToPtr) {
4675+
Value *Op0 = O->getOperand(0);
4676+
if (isa<ConstantInt>(Op0))
4677+
return getConstant(dyn_cast<ConstantInt>(Op0));
4678+
return getSCEV(Op0);
4679+
}
4680+
}
4681+
}
4682+
46704683
// Any other expression must be a pointer base.
46714684
return getZero(P->getType());
46724685
}
46734686

4687+
static bool isIntToPtr(const SCEV *V) {
4688+
if (auto *Unknown = dyn_cast<SCEVUnknown>(V))
4689+
if (auto *Op = dyn_cast<Operator>(Unknown->getValue()))
4690+
return Op->getOpcode() == Instruction::IntToPtr;
4691+
return false;
4692+
}
4693+
46744694
const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
46754695
SCEV::NoWrapFlags Flags,
46764696
unsigned Depth) {
46774697
// Fast path: X - X --> 0.
46784698
if (LHS == RHS)
46794699
return getZero(LHS->getType());
46804700

4681-
// If we subtract two pointers with different pointer bases, bail.
4701+
// If we subtract two pointers except inttoptrs with different pointer bases,
4702+
// bail.
46824703
// Eventually, we're going to add an assertion to getMulExpr that we
46834704
// can't multiply by a pointer.
46844705
if (RHS->getType()->isPointerTy()) {
4706+
const SCEV *LBase = getPointerBase(LHS);
4707+
const SCEV *RBase = getPointerBase(RHS);
46854708
if (!LHS->getType()->isPointerTy() ||
4686-
getPointerBase(LHS) != getPointerBase(RHS))
4709+
(LBase != RBase && (!isIntToPtr(LBase) || !isIntToPtr(RBase))))
46874710
return getCouldNotCompute();
46884711
LHS = removePointerBase(LHS);
46894712
RHS = removePointerBase(RHS);
@@ -4718,7 +4741,8 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
47184741
// larger scope than intended.
47194742
auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
47204743

4721-
return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4744+
return getAddExprFromMismatchedTypes(LHS, getNegativeSCEV(RHS, NegFlags),
4745+
AddFlags, Depth);
47224746
}
47234747

47244748
const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
@@ -4745,6 +4769,18 @@ const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
47454769
return getSignExtendExpr(V, Ty, Depth);
47464770
}
47474771

4772+
const SCEV *ScalarEvolution::getTruncateOrAnyExtend(const SCEV *V, Type *Ty,
4773+
unsigned Depth) {
4774+
Type *SrcTy = V->getType();
4775+
assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4776+
"Cannot truncate or any extend with non-integer arguments!");
4777+
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4778+
return V; // No conversion
4779+
if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4780+
return getTruncateExpr(V, Ty, Depth);
4781+
return getAnyExtendExpr(V, Ty);
4782+
}
4783+
47484784
const SCEV *
47494785
ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
47504786
Type *SrcTy = V->getType();
@@ -4839,6 +4875,58 @@ ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
48394875
return getUMinExpr(PromotedOps, Sequential);
48404876
}
48414877

4878+
const SCEV *ScalarEvolution::getAddRecExprFromMismatchedTypes(
4879+
const SmallVectorImpl<const SCEV *> &Ops, const Loop *L,
4880+
SCEV::NoWrapFlags Flags) {
4881+
assert(!Ops.empty() && "At least one operand must be!");
4882+
// Trivial case.
4883+
if (Ops.size() == 1)
4884+
return Ops[0];
4885+
4886+
// Find the max type first.
4887+
Type *MaxType = nullptr;
4888+
for (const auto *S : Ops)
4889+
if (MaxType)
4890+
MaxType = getWiderType(MaxType, S->getType());
4891+
else
4892+
MaxType = S->getType();
4893+
assert(MaxType && "Failed to find maximum type!");
4894+
4895+
// Extend all ops to max type.
4896+
SmallVector<const SCEV *, 2> PromotedOps;
4897+
PromotedOps.reserve(Ops.size());
4898+
for (const auto *S : Ops)
4899+
PromotedOps.push_back(getNoopOrAnyExtend(S, MaxType));
4900+
4901+
return getAddRecExpr(PromotedOps, L, Flags);
4902+
}
4903+
4904+
const SCEV *ScalarEvolution::getAddExprFromMismatchedTypes(
4905+
const SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags,
4906+
unsigned Depth) {
4907+
assert(!Ops.empty() && "At least one operand must be!");
4908+
// Trivial case.
4909+
if (Ops.size() == 1)
4910+
return Ops[0];
4911+
4912+
// Find the max type first.
4913+
Type *MaxType = nullptr;
4914+
for (const auto *S : Ops)
4915+
if (MaxType)
4916+
MaxType = getWiderType(MaxType, S->getType());
4917+
else
4918+
MaxType = S->getType();
4919+
assert(MaxType && "Failed to find maximum type!");
4920+
4921+
// Extend all ops to max type.
4922+
SmallVector<const SCEV *, 2> PromotedOps;
4923+
PromotedOps.reserve(Ops.size());
4924+
for (const auto *S : Ops)
4925+
PromotedOps.push_back(getNoopOrAnyExtend(S, MaxType));
4926+
4927+
return getAddExpr(PromotedOps, Flags, Depth);
4928+
}
4929+
48424930
const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
48434931
// A pointer operand may evaluate to a nonpointer expression, such as null.
48444932
if (!V->getType()->isPointerTy())

llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,13 @@ AliasResult SCEVAAResult::alias(const MemoryLocation &LocA,
6767
// Test whether the difference is known to be great enough that memory of
6868
// the given sizes don't overlap. This assumes that ASizeInt and BSizeInt
6969
// are non-zero, which is special-cased above.
70-
if (!isa<SCEVCouldNotCompute>(BA) &&
71-
ASizeInt.ule(SE.getUnsignedRange(BA).getUnsignedMin()) &&
72-
(-BSizeInt).uge(SE.getUnsignedRange(BA).getUnsignedMax()))
73-
return AliasResult::NoAlias;
70+
if (!isa<SCEVCouldNotCompute>(BA)) {
71+
if (SE.isSCEVable(BA->getType()))
72+
BA = SE.getTruncateOrAnyExtend(BA, AS->getType());
73+
if (ASizeInt.ule(SE.getUnsignedRange(BA).getUnsignedMin()) &&
74+
(-BSizeInt).uge(SE.getUnsignedRange(BA).getUnsignedMax()))
75+
return AliasResult::NoAlias;
76+
}
7477

7578
// Folding the subtraction while preserving range information can be tricky
7679
// (because of INT_MIN, etc.); if the prior test failed, swap AS and BS
@@ -82,10 +85,13 @@ AliasResult SCEVAAResult::alias(const MemoryLocation &LocA,
8285
// Test whether the difference is known to be great enough that memory of
8386
// the given sizes don't overlap. This assumes that ASizeInt and BSizeInt
8487
// are non-zero, which is special-cased above.
85-
if (!isa<SCEVCouldNotCompute>(AB) &&
86-
BSizeInt.ule(SE.getUnsignedRange(AB).getUnsignedMin()) &&
87-
(-ASizeInt).uge(SE.getUnsignedRange(AB).getUnsignedMax()))
88-
return AliasResult::NoAlias;
88+
if (!isa<SCEVCouldNotCompute>(AB)) {
89+
if (SE.isSCEVable(AB->getType()))
90+
AB = SE.getTruncateOrAnyExtend(AB, AS->getType());
91+
if (BSizeInt.ule(SE.getUnsignedRange(AB).getUnsignedMin()) &&
92+
(-ASizeInt).uge(SE.getUnsignedRange(AB).getUnsignedMax()))
93+
return AliasResult::NoAlias;
94+
}
8995
}
9096

9197
// If ScalarEvolution can find an underlying object, form a new query.

llvm/test/Analysis/ScalarEvolution/scev-aa.ll

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,29 @@ for.latch:
340340
for.end:
341341
ret void
342342
}
343+
344+
; CHECK-LABEL: Function: test_different_pointer_bases_of_inttoptr: 2 pointers, 0 call sites
345+
; CHECK: NoAlias: <16 x i8>* %tmp5, <16 x i8>* %tmp7
346+
347+
define void @test_different_pointer_bases_of_inttoptr() {
348+
entry:
349+
br label %for.body
350+
351+
for.body:
352+
%tmp = phi i32 [ %next, %for.body ], [ 1, %entry ]
353+
%tmp1 = shl nsw i32 %tmp, 1
354+
%tmp2 = add nuw nsw i32 %tmp1, %tmp1
355+
%tmp3 = mul nsw i32 %tmp2, 1408
356+
%tmp4 = add nsw i32 %tmp3, 1408
357+
%tmp5 = getelementptr inbounds i8, ptr inttoptr (i32 1024 to ptr), i32 %tmp1
358+
%tmp6 = load <16 x i8>, ptr %tmp5, align 1
359+
%tmp7 = getelementptr inbounds i8, ptr inttoptr (i32 4096 to ptr), i32 %tmp4
360+
store <16 x i8> %tmp6, ptr %tmp7, align 1
361+
362+
%next = add i32 %tmp, 2
363+
%exitcond = icmp slt i32 %next, 10
364+
br i1 %exitcond, label %for.body, label %for.end
365+
366+
for.end:
367+
ret void
368+
}

0 commit comments

Comments
 (0)