-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[InstCombine] Support nested GEPs in OptimizePointerDifference #142958
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Nikita Popov (nikic) Changes: Currently OptimizePointerDifference() only handles single GEPs with a common base, not GEP chains. This patch generalizes the support to nested GEPs with a common base. Finding the common base is a bit annoying because we want to stop as soon as possible and not recurse into common GEP prefixes. This helps avoid regressions from #137297. Full diff: https://github.com/llvm/llvm-project/pull/142958.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index a9ac5ff9b9c89..eb5537717327b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2068,71 +2068,119 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return nullptr;
}
+struct CommonBase {
+ /// Common base pointer.
+ Value *Ptr = nullptr;
+ /// LHS GEPs until common base.
+ SmallVector<GEPOperator *> LHSGEPs;
+ /// RHS GEPs until common base.
+ SmallVector<GEPOperator *> RHSGEPs;
+ /// LHS GEP NoWrapFlags until common base.
+ GEPNoWrapFlags LHSNW = GEPNoWrapFlags::all();
+ /// RHS GEP NoWrapFlags until common base.
+ GEPNoWrapFlags RHSNW = GEPNoWrapFlags::all();
+};
+
+static CommonBase computeCommonBase(Value *LHS, Value *RHS) {
+ CommonBase Base;
+
+ if (LHS->getType() != RHS->getType())
+ return Base;
+
+ // Collect all base pointers of LHS.
+ SmallPtrSet<Value *, 16> Ptrs;
+ Value *Ptr = LHS;
+ while (true) {
+ Ptrs.insert(Ptr);
+ if (auto *GEP = dyn_cast<GEPOperator>(Ptr))
+ Ptr = GEP->getPointerOperand();
+ else
+ break;
+ }
+
+ // Find common base and collect RHS GEPs.
+ while (true) {
+ if (Ptrs.contains(RHS)) {
+ if (LHS->getType() != RHS->getType())
+ return Base;
+ Base.Ptr = RHS;
+ break;
+ }
+
+ if (auto *GEP = dyn_cast<GEPOperator>(RHS)) {
+ Base.RHSGEPs.push_back(GEP);
+ Base.RHSNW &= GEP->getNoWrapFlags();
+ RHS = GEP->getPointerOperand();
+ } else {
+ // No common base.
+ return Base;
+ }
+ }
+
+ // Collect LHS GEPs.
+ while (true) {
+ if (LHS == Base.Ptr)
+ break;
+
+ auto *GEP = cast<GEPOperator>(LHS);
+ Base.LHSGEPs.push_back(GEP);
+ Base.LHSNW &= GEP->getNoWrapFlags();
+ LHS = GEP->getPointerOperand();
+ }
+
+ return Base;
+}
+
/// Optimize pointer differences into the same array into a size. Consider:
/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
Type *Ty, bool IsNUW) {
- // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
- // this.
- bool Swapped = false;
- GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
- if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) {
- std::swap(LHS, RHS);
- Swapped = true;
- }
-
- // Require at least one GEP with a common base pointer on both sides.
- if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
- // (gep X, ...) - X
- if (LHSGEP->getOperand(0)->stripPointerCasts() ==
- RHS->stripPointerCasts()) {
- GEP1 = LHSGEP;
- } else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
- // (gep X, ...) - (gep X, ...)
- if (LHSGEP->getOperand(0)->stripPointerCasts() ==
- RHSGEP->getOperand(0)->stripPointerCasts()) {
- GEP1 = LHSGEP;
- GEP2 = RHSGEP;
- }
- }
- }
-
- if (!GEP1)
+ CommonBase Base = computeCommonBase(LHS, RHS);
+ if (!Base.Ptr)
return nullptr;
+
// To avoid duplicating the offset arithmetic, rewrite the GEP to use the
- // computed offset. This may erase the original GEP, so be sure to cache the
- // nowrap flags before emitting the offset.
+ // computed offset.
// TODO: We should probably do this even if there is only one GEP.
- bool RewriteGEPs = GEP2 != nullptr;
+ bool RewriteGEPs = !Base.LHSGEPs.empty() && !Base.RHSGEPs.empty();
+
+ Type *IdxTy = DL.getIndexType(Base.Ptr->getType());
+ auto EmitOffsetFromBase = [&](ArrayRef<GEPOperator *> GEPs) -> Value * {
+ Value *Sum = nullptr;
+ for (GEPOperator *GEP : reverse(GEPs)) {
+ Value *Offset = EmitGEPOffset(GEP, RewriteGEPs);
+ if (Sum)
+ Sum = Builder.CreateAdd(Sum, Offset);
+ else
+ Sum = Offset;
+ }
+ if (!Sum)
+ return Constant::getNullValue(IdxTy);
+ return Sum;
+ };
- // Emit the offset of the GEP and an intptr_t.
- GEPNoWrapFlags GEP1NW = GEP1->getNoWrapFlags();
- Value *Result = EmitGEPOffset(GEP1, RewriteGEPs);
+ Value *Result = EmitOffsetFromBase(Base.LHSGEPs);
+ Value *Offset2 = EmitOffsetFromBase(Base.RHSGEPs);
// If this is a single inbounds GEP and the original sub was nuw,
// then the final multiplication is also nuw.
if (auto *I = dyn_cast<Instruction>(Result))
- if (IsNUW && !GEP2 && !Swapped && GEP1NW.isInBounds() &&
+ if (IsNUW && match(Offset2, m_Zero()) && Base.LHSNW.isInBounds() &&
I->getOpcode() == Instruction::Mul)
I->setHasNoUnsignedWrap();
// If we have a 2nd GEP of the same base pointer, subtract the offsets.
// If both GEPs are inbounds, then the subtract does not have signed overflow.
// If both GEPs are nuw and the original sub is nuw, the new sub is also nuw.
- if (GEP2) {
- GEPNoWrapFlags GEP2NW = GEP2->getNoWrapFlags();
- Value *Offset = EmitGEPOffset(GEP2, RewriteGEPs);
- Result = Builder.CreateSub(Result, Offset, "gepdiff",
- IsNUW && GEP1NW.hasNoUnsignedWrap() &&
- GEP2NW.hasNoUnsignedWrap(),
- GEP1NW.isInBounds() && GEP2NW.isInBounds());
- }
-
- // If we have p - gep(p, ...) then we have to negate the result.
- if (Swapped)
- Result = Builder.CreateNeg(Result, "diff.neg");
+ if (!match(Offset2, m_Zero())) {
+ Result =
+ Builder.CreateSub(Result, Offset2, "gepdiff",
+ IsNUW && Base.LHSNW.hasNoUnsignedWrap() &&
+ Base.RHSNW.hasNoUnsignedWrap(),
+ Base.LHSNW.isInBounds() && Base.RHSNW.isInBounds());
+ }
return Builder.CreateIntCast(Result, Ty, true);
}
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index f5df8573d6304..365c17b35a468 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -506,8 +506,7 @@ define <2 x i1> @test23vec(<2 x i32> %x) {
; unsigned overflow does not happen during offset computation
define i1 @test24_neg_offs(ptr %p, i64 %offs) {
; CHECK-LABEL: @test24_neg_offs(
-; CHECK-NEXT: [[P1_IDX_NEG:%.*]] = mul i64 [[OFFS:%.*]], -4
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[P1_IDX_NEG]], 8
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OFFS:%.*]], -2
; CHECK-NEXT: ret i1 [[CMP]]
;
%p1 = getelementptr inbounds i32, ptr %p, i64 %offs
diff --git a/llvm/test/Transforms/InstCombine/sub-gep.ll b/llvm/test/Transforms/InstCombine/sub-gep.ll
index c86a1a37bd7ad..4a27b04f724d8 100644
--- a/llvm/test/Transforms/InstCombine/sub-gep.ll
+++ b/llvm/test/Transforms/InstCombine/sub-gep.ll
@@ -80,7 +80,7 @@ define i32 @test_inbounds_nuw_trunc(ptr %base, i64 %idx) {
define i64 @test_inbounds_nuw_swapped(ptr %base, i64 %idx) {
; CHECK-LABEL: @test_inbounds_nuw_swapped(
-; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
+; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
;
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
@@ -104,7 +104,7 @@ define i64 @test_inbounds1_nuw_swapped(ptr %base, i64 %idx) {
define i64 @test_inbounds2_nuw_swapped(ptr %base, i64 %idx) {
; CHECK-LABEL: @test_inbounds2_nuw_swapped(
-; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
+; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
;
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
@@ -279,8 +279,8 @@ define i16 @test24_as1(ptr addrspace(1) %P, i16 %A) {
define i64 @test24a(ptr %P, i64 %A){
; CHECK-LABEL: @test24a(
-; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i64 0, [[A:%.*]]
-; CHECK-NEXT: ret i64 [[DIFF_NEG]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i64 0, [[A:%.*]]
+; CHECK-NEXT: ret i64 [[GEPDIFF]]
;
%B = getelementptr inbounds i8, ptr %P, i64 %A
%C = ptrtoint ptr %B to i64
@@ -291,8 +291,8 @@ define i64 @test24a(ptr %P, i64 %A){
define i16 @test24a_as1(ptr addrspace(1) %P, i16 %A) {
; CHECK-LABEL: @test24a_as1(
-; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i16 0, [[A:%.*]]
-; CHECK-NEXT: ret i16 [[DIFF_NEG]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i16 0, [[A:%.*]]
+; CHECK-NEXT: ret i16 [[GEPDIFF]]
;
%B = getelementptr inbounds i8, ptr addrspace(1) %P, i16 %A
%C = ptrtoint ptr addrspace(1) %B to i16
@@ -860,3 +860,85 @@ _Z3fooPKc.exit:
%tobool = icmp eq i64 %2, 0
ret i1 %tobool
}
+
+define i64 @multiple_geps_one_chain(ptr %base, i64 %idx, i64 %idx2) {
+; CHECK-LABEL: @multiple_geps_one_chain(
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[D:%.*]] = shl i64 [[P2_IDX1]], 2
+; CHECK-NEXT: ret i64 [[D]]
+;
+ %p2 = getelementptr inbounds i32, ptr %base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %i1 = ptrtoint ptr %base to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i2, %i1
+ ret i64 %d
+}
+
+define i64 @multiple_geps_one_chain_commuted(ptr %base, i64 %idx, i64 %idx2) {
+; CHECK-LABEL: @multiple_geps_one_chain_commuted(
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[DOTNEG:%.*]] = mul i64 [[P2_IDX1]], -4
+; CHECK-NEXT: ret i64 [[DOTNEG]]
+;
+ %p2 = getelementptr inbounds i32, ptr %base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %i1 = ptrtoint ptr %base to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i1, %i2
+ ret i64 %d
+}
+
+define i64 @multiple_geps_two_chains(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
+; CHECK-LABEL: @multiple_geps_two_chains(
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
+; CHECK-NEXT: ret i64 [[GEPDIFF]]
+;
+ %p2 = getelementptr inbounds i32, ptr %base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %p4 = getelementptr inbounds i32, ptr %base, i64 %idx3
+ %i1 = ptrtoint ptr %p4 to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i2, %i1
+ ret i64 %d
+}
+
+define i64 @multiple_geps_two_chains_commuted(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
+; CHECK-LABEL: @multiple_geps_two_chains_commuted(
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[IDX3:%.*]], [[P2_IDX1]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
+; CHECK-NEXT: ret i64 [[GEPDIFF]]
+;
+ %p2 = getelementptr inbounds i32, ptr %base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %p4 = getelementptr inbounds i32, ptr %base, i64 %idx3
+ %i1 = ptrtoint ptr %p4 to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i1, %i2
+ ret i64 %d
+}
+
+declare void @use(ptr)
+
+define i64 @multiple_geps_two_chains_gep_base(ptr %base, i64 %base.idx, i64 %idx, i64 %idx2, i64 %idx3) {
+; CHECK-LABEL: @multiple_geps_two_chains_gep_base(
+; CHECK-NEXT: [[GEP_BASE:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[BASE_IDX:%.*]]
+; CHECK-NEXT: call void @use(ptr [[GEP_BASE]])
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
+; CHECK-NEXT: ret i64 [[GEPDIFF]]
+;
+ %gep.base = getelementptr inbounds i32, ptr %base, i64 %base.idx
+ call void @use(ptr %gep.base)
+ %p2 = getelementptr inbounds i32, ptr %gep.base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %p4 = getelementptr inbounds i32, ptr %gep.base, i64 %idx3
+ %i1 = ptrtoint ptr %p4 to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i2, %i1
+ ret i64 %d
+}
|
llvm-opt-benchmark results: dtcxzyw/llvm-opt-benchmark#2402 |
✅ With the latest revision this PR passed the C/C++ code formatter. |
37d8ac6
to
e4e42b7
Compare
|
||
// If this is a single inbounds GEP and the original sub was nuw, | ||
// then the final multiplication is also nuw. | ||
if (auto *I = dyn_cast<Instruction>(Result)) | ||
if (IsNUW && !GEP2 && !Swapped && GEP1NW.isInBounds() && | ||
if (IsNUW && match(Offset2, m_Zero()) && Base.LHSNW.isInBounds() && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we check `nusw` instead? I don't know why Alive2 complains about this: https://alive2.llvm.org/ce/z/72ZFQa
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Based on the proof, I assume this is commenting on the isInBounds use below, not this one.)
I think it's easier to understand if you consider the `(ptradd(p, a) - ptradd(p, b))` case. With `nusw`, if `p` is sitting in the middle of the address space, you could have `a` as a large positive value and `b` as a large negative one, with overflow if you subtract them. With `inbounds`, it is guaranteed that the distance between `a` and `b` cannot exceed half the address space.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
When expanding the offset of a GEP chain via a series of adds, try to preserve the nsw/nuw flags based on inbounds/nuw. This is a followup to #142958. Proof: https://alive2.llvm.org/ce/z/8HiFYY (note that preserving nsw in the nusw case is not valid)
…(#143488) When expanding the offset of a GEP chain via a series of adds, try to preserve the nsw/nuw flags based on inbounds/nuw. This is a followup to llvm/llvm-project#142958. Proof: https://alive2.llvm.org/ce/z/8HiFYY (note that preserving nsw in the nusw case is not valid)
…142958) Currently OptimizePointerDifference() only handles single GEPs with a common base, not GEP chains. This patch generalizes the support to nested GEPs with a common base. Finding the common base is a bit annoying because we want to stop as soon as possible and not recurse into common GEP prefixes. This helps avoid regressions from llvm#137297.
When expanding the offset of a GEP chain via a series of adds, try to preserve the nsw/nuw flags based on inbounds/nuw. This is a followup to llvm#142958. Proof: https://alive2.llvm.org/ce/z/8HiFYY (note that preserving nsw in the nusw case is not valid)
When looking for the common base pointer, support the case where the type changes because the GEP goes from pointer to vector of pointers. This was supported prior to llvm#142958. To correctly handle the multiple GEP case, allow specifying the index type to use in emitGEPOffset(), so we can perform all calculation in the vector type and thus avoid a type mismatch.
When looking for the common base pointer, support the case where the type changes because the GEP goes from pointer to vector of pointers. This was supported prior to llvm#142958.
…142958) Currently OptimizePointerDifference() only handles single GEPs with a common base, not GEP chains. This patch generalizes the support to nested GEPs with a common base. Finding the common base is a bit annoying because we want to stop as soon as possible and not recurse into common GEP prefixes. This helps avoid regressions from llvm#137297.
When expanding the offset of a GEP chain via a series of adds, try to preserve the nsw/nuw flags based on inbounds/nuw. This is a followup to llvm#142958. Proof: https://alive2.llvm.org/ce/z/8HiFYY (note that preserving nsw in the nusw case is not valid)
llvm#143906) When looking for the common base pointer, support the case where the type changes because the GEP goes from pointer to vector of pointers. This was supported prior to llvm#142958.
llvm#143906) When looking for the common base pointer, support the case where the type changes because the GEP goes from pointer to vector of pointers. This was supported prior to llvm#142958.
Currently OptimizePointerDifference() only handles single GEPs with a common base, not GEP chains. This patch generalizes the support to nested GEPs with a common base.
Finding the common base is a bit annoying because we want to stop as soon as possible and not recurse into common GEP prefixes.
This helps avoid regressions from #137297.