@@ -4650,7 +4650,8 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4650
4650
Ops[0] = removePointerBase(Ops[0]);
4651
4651
// Don't try to transfer nowrap flags for now. We could in some cases
4652
4652
// (for example, if pointer operand of the AddRec is a SCEVUnknown).
4653
- return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4653
+ return getAddRecExprFromMismatchedTypes(Ops, AddRec->getLoop(),
4654
+ SCEV::FlagAnyWrap);
4654
4655
}
4655
4656
if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4656
4657
// The base of an Add is the pointer operand.
@@ -4665,25 +4666,47 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4665
4666
*PtrOp = removePointerBase(*PtrOp);
4666
4667
// Don't try to transfer nowrap flags for now. We could in some cases
4667
4668
// (for example, if the pointer operand of the Add is a SCEVUnknown).
4668
- return getAddExpr (Ops);
4669
+ return getAddExprFromMismatchedTypes (Ops);
4669
4670
}
4671
+
4672
+ if (auto *Unknown = dyn_cast<SCEVUnknown>(P)) {
4673
+ if (auto *O = dyn_cast<Operator>(Unknown->getValue())) {
4674
+ if (O->getOpcode() == Instruction::IntToPtr) {
4675
+ Value *Op0 = O->getOperand(0);
4676
+ if (isa<ConstantInt>(Op0))
4677
+ return getConstant(dyn_cast<ConstantInt>(Op0));
4678
+ return getSCEV(Op0);
4679
+ }
4680
+ }
4681
+ }
4682
+
4670
4683
// Any other expression must be a pointer base.
4671
4684
return getZero(P->getType());
4672
4685
}
4673
4686
4687
+ static bool isIntToPtr(const SCEV *V) {
4688
+ if (auto *Unknown = dyn_cast<SCEVUnknown>(V))
4689
+ if (auto *Op = dyn_cast<Operator>(Unknown->getValue()))
4690
+ return Op->getOpcode() == Instruction::IntToPtr;
4691
+ return false;
4692
+ }
4693
+
4674
4694
const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4675
4695
SCEV::NoWrapFlags Flags,
4676
4696
unsigned Depth) {
4677
4697
// Fast path: X - X --> 0.
4678
4698
if (LHS == RHS)
4679
4699
return getZero(LHS->getType());
4680
4700
4681
- // If we subtract two pointers with different pointer bases, bail.
4701
+ // If we subtract two pointers except inttoptrs with different pointer bases,
4702
+ // bail.
4682
4703
// Eventually, we're going to add an assertion to getMulExpr that we
4683
4704
// can't multiply by a pointer.
4684
4705
if (RHS->getType()->isPointerTy()) {
4706
+ const SCEV *LBase = getPointerBase(LHS);
4707
+ const SCEV *RBase = getPointerBase(RHS);
4685
4708
if (!LHS->getType()->isPointerTy() ||
4686
- getPointerBase(LHS) != getPointerBase(RHS ))
4709
+ (LBase != RBase && (!isIntToPtr(LBase) || !isIntToPtr(RBase)) ))
4687
4710
return getCouldNotCompute();
4688
4711
LHS = removePointerBase(LHS);
4689
4712
RHS = removePointerBase(RHS);
@@ -4718,7 +4741,8 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4718
4741
// larger scope than intended.
4719
4742
auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4720
4743
4721
- return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4744
+ return getAddExprFromMismatchedTypes(LHS, getNegativeSCEV(RHS, NegFlags),
4745
+ AddFlags, Depth);
4722
4746
}
4723
4747
4724
4748
const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
@@ -4745,6 +4769,18 @@ const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4745
4769
return getSignExtendExpr(V, Ty, Depth);
4746
4770
}
4747
4771
4772
+ const SCEV *ScalarEvolution::getTruncateOrAnyExtend(const SCEV *V, Type *Ty,
4773
+ unsigned Depth) {
4774
+ Type *SrcTy = V->getType();
4775
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4776
+ "Cannot truncate or any extend with non-integer arguments!");
4777
+ if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4778
+ return V; // No conversion
4779
+ if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4780
+ return getTruncateExpr(V, Ty, Depth);
4781
+ return getAnyExtendExpr(V, Ty);
4782
+ }
4783
+
4748
4784
const SCEV *
4749
4785
ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4750
4786
Type *SrcTy = V->getType();
@@ -4839,6 +4875,58 @@ ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4839
4875
return getUMinExpr(PromotedOps, Sequential);
4840
4876
}
4841
4877
4878
+ const SCEV *ScalarEvolution::getAddRecExprFromMismatchedTypes(
4879
+ const SmallVectorImpl<const SCEV *> &Ops, const Loop *L,
4880
+ SCEV::NoWrapFlags Flags) {
4881
+ assert(!Ops.empty() && "At least one operand must be!");
4882
+ // Trivial case.
4883
+ if (Ops.size() == 1)
4884
+ return Ops[0];
4885
+
4886
+ // Find the max type first.
4887
+ Type *MaxType = nullptr;
4888
+ for (const auto *S : Ops)
4889
+ if (MaxType)
4890
+ MaxType = getWiderType(MaxType, S->getType());
4891
+ else
4892
+ MaxType = S->getType();
4893
+ assert(MaxType && "Failed to find maximum type!");
4894
+
4895
+ // Extend all ops to max type.
4896
+ SmallVector<const SCEV *, 2> PromotedOps;
4897
+ PromotedOps.reserve(Ops.size());
4898
+ for (const auto *S : Ops)
4899
+ PromotedOps.push_back(getNoopOrAnyExtend(S, MaxType));
4900
+
4901
+ return getAddRecExpr(PromotedOps, L, Flags);
4902
+ }
4903
+
4904
+ const SCEV *ScalarEvolution::getAddExprFromMismatchedTypes(
4905
+ const SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags,
4906
+ unsigned Depth) {
4907
+ assert(!Ops.empty() && "At least one operand must be!");
4908
+ // Trivial case.
4909
+ if (Ops.size() == 1)
4910
+ return Ops[0];
4911
+
4912
+ // Find the max type first.
4913
+ Type *MaxType = nullptr;
4914
+ for (const auto *S : Ops)
4915
+ if (MaxType)
4916
+ MaxType = getWiderType(MaxType, S->getType());
4917
+ else
4918
+ MaxType = S->getType();
4919
+ assert(MaxType && "Failed to find maximum type!");
4920
+
4921
+ // Extend all ops to max type.
4922
+ SmallVector<const SCEV *, 2> PromotedOps;
4923
+ PromotedOps.reserve(Ops.size());
4924
+ for (const auto *S : Ops)
4925
+ PromotedOps.push_back(getNoopOrAnyExtend(S, MaxType));
4926
+
4927
+ return getAddExpr(PromotedOps, Flags, Depth);
4928
+ }
4929
+
4842
4930
const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4843
4931
// A pointer operand may evaluate to a nonpointer expression, such as null.
4844
4932
if (!V->getType()->isPointerTy())
0 commit comments