Skip to content

Commit 20e8de9

Browse files
authored
[InstCombine] Support nested GEPs in OptimizePointerDifference (#142958)
Currently OptimizePointerDifference() only handles single GEPs with a common base, not GEP chains. This patch generalizes the support to nested GEPs with a common base. Finding the common base is a bit annoying because we want to stop as soon as possible and not recurse into common GEP prefixes. This helps avoid regressions from #137297.
1 parent ecc8b29 commit 20e8de9

File tree

3 files changed

+181
-53
lines changed

3 files changed

+181
-53
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 92 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,71 +2068,118 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
20682068
return nullptr;
20692069
}
20702070

2071+
/// Result of computeCommonBase(): the pointer shared by the GEP chains of two
/// pointers, plus the GEPs (and their combined no-wrap flags) that lead from
/// each side down to that pointer.
struct CommonBase {
2072+
/// Common base pointer. Remains nullptr when no common base was found.
2073+
Value *Ptr = nullptr;
2074+
/// GEPs walked from LHS down to (but excluding) the common base, in
/// outermost-first order.
2075+
SmallVector<GEPOperator *> LHSGEPs;
2076+
/// GEPs walked from RHS down to (but excluding) the common base, in
/// outermost-first order.
2077+
SmallVector<GEPOperator *> RHSGEPs;
2078+
/// Intersection of the no-wrap flags of all LHSGEPs. Starts at all() so
/// that an empty chain imposes no restriction.
2079+
GEPNoWrapFlags LHSNW = GEPNoWrapFlags::all();
2080+
/// Intersection of the no-wrap flags of all RHSGEPs. Starts at all() so
/// that an empty chain imposes no restriction.
2081+
GEPNoWrapFlags RHSNW = GEPNoWrapFlags::all();
2082+
};
2083+
2084+
/// Find the common base pointer of the GEP chains rooted at \p LHS and
/// \p RHS.
///
/// The search records every pointer on LHS's chain, then walks RHS's chain
/// and stops at the first pointer that is also on LHS's chain. On success
/// the returned CommonBase has Ptr set and lists the GEPs (and intersected
/// no-wrap flags) between each side and that pointer. On failure Ptr is
/// left as nullptr; the other fields may be partially filled and should be
/// ignored.
static CommonBase computeCommonBase(Value *LHS, Value *RHS) {
2085+
CommonBase Base;
2086+
2087+
// Pointers of different types are not handled.
if (LHS->getType() != RHS->getType())
2088+
return Base;
2089+
2090+
// Collect all base pointers of LHS: LHS itself and the pointer operand of
// every GEP on its chain, down to the first non-GEP value.
2091+
SmallPtrSet<Value *, 16> Ptrs;
2092+
Value *Ptr = LHS;
2093+
while (true) {
2094+
Ptrs.insert(Ptr);
2095+
if (auto *GEP = dyn_cast<GEPOperator>(Ptr))
2096+
Ptr = GEP->getPointerOperand();
2097+
else
2098+
break;
2099+
}
2100+
2101+
// Find common base and collect RHS GEPs. RHS is advanced down its own
// chain until it reaches a pointer that was recorded for LHS.
2102+
while (true) {
2103+
if (Ptrs.contains(RHS)) {
2104+
// RHS now names the candidate base; reject it if its type differs from
// the (still unmodified) LHS pointer's type.
if (LHS->getType() != RHS->getType())
2105+
return Base;
2106+
Base.Ptr = RHS;
2107+
break;
2108+
}
2109+
2110+
if (auto *GEP = dyn_cast<GEPOperator>(RHS)) {
2111+
Base.RHSGEPs.push_back(GEP);
2112+
Base.RHSNW &= GEP->getNoWrapFlags();
2113+
RHS = GEP->getPointerOperand();
2114+
} else {
2115+
// No common base: RHS's chain ended without meeting LHS's chain.
2116+
return Base;
2117+
}
2118+
}
2119+
2120+
// Collect LHS GEPs between LHS and the common base.
2121+
while (true) {
2122+
if (LHS == Base.Ptr)
2123+
break;
2124+
2125+
// Base.Ptr was taken from LHS's chain, so every value visited before
// reaching it is a GEP; hence the asserting cast<>.
auto *GEP = cast<GEPOperator>(LHS);
2126+
Base.LHSGEPs.push_back(GEP);
2127+
Base.LHSNW &= GEP->getNoWrapFlags();
2128+
LHS = GEP->getPointerOperand();
2129+
}
2130+
2131+
return Base;
2132+
}
2133+
20712134
/// Optimize pointer differences into the same array into a size. Consider:
20722135
/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer
20732136
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
20742137
Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
20752138
Type *Ty, bool IsNUW) {
2076-
// If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
2077-
// this.
2078-
bool Swapped = false;
2079-
GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
2080-
if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) {
2081-
std::swap(LHS, RHS);
2082-
Swapped = true;
2083-
}
2084-
2085-
// Require at least one GEP with a common base pointer on both sides.
2086-
if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
2087-
// (gep X, ...) - X
2088-
if (LHSGEP->getOperand(0)->stripPointerCasts() ==
2089-
RHS->stripPointerCasts()) {
2090-
GEP1 = LHSGEP;
2091-
} else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
2092-
// (gep X, ...) - (gep X, ...)
2093-
if (LHSGEP->getOperand(0)->stripPointerCasts() ==
2094-
RHSGEP->getOperand(0)->stripPointerCasts()) {
2095-
GEP1 = LHSGEP;
2096-
GEP2 = RHSGEP;
2097-
}
2098-
}
2099-
}
2100-
2101-
if (!GEP1)
2139+
CommonBase Base = computeCommonBase(LHS, RHS);
2140+
if (!Base.Ptr)
21022141
return nullptr;
21032142

21042143
// To avoid duplicating the offset arithmetic, rewrite the GEP to use the
2105-
// computed offset. This may erase the original GEP, so be sure to cache the
2106-
// nowrap flags before emitting the offset.
2144+
// computed offset.
21072145
// TODO: We should probably do this even if there is only one GEP.
2108-
bool RewriteGEPs = GEP2 != nullptr;
2146+
bool RewriteGEPs = !Base.LHSGEPs.empty() && !Base.RHSGEPs.empty();
2147+
2148+
Type *IdxTy = DL.getIndexType(Base.Ptr->getType());
2149+
auto EmitOffsetFromBase = [&](ArrayRef<GEPOperator *> GEPs) -> Value * {
2150+
Value *Sum = nullptr;
2151+
for (GEPOperator *GEP : reverse(GEPs)) {
2152+
Value *Offset = EmitGEPOffset(GEP, RewriteGEPs);
2153+
if (Sum)
2154+
Sum = Builder.CreateAdd(Sum, Offset);
2155+
else
2156+
Sum = Offset;
2157+
}
2158+
if (!Sum)
2159+
return Constant::getNullValue(IdxTy);
2160+
return Sum;
2161+
};
21092162

2110-
// Emit the offset of the GEP and an intptr_t.
2111-
GEPNoWrapFlags GEP1NW = GEP1->getNoWrapFlags();
2112-
Value *Result = EmitGEPOffset(GEP1, RewriteGEPs);
2163+
Value *Result = EmitOffsetFromBase(Base.LHSGEPs);
2164+
Value *Offset2 = EmitOffsetFromBase(Base.RHSGEPs);
21132165

21142166
// If this is a single inbounds GEP and the original sub was nuw,
21152167
// then the final multiplication is also nuw.
21162168
if (auto *I = dyn_cast<Instruction>(Result))
2117-
if (IsNUW && !GEP2 && !Swapped && GEP1NW.isInBounds() &&
2169+
if (IsNUW && match(Offset2, m_Zero()) && Base.LHSNW.isInBounds() &&
21182170
I->getOpcode() == Instruction::Mul)
21192171
I->setHasNoUnsignedWrap();
21202172

21212173
// If we have a 2nd GEP of the same base pointer, subtract the offsets.
21222174
// If both GEPs are inbounds, then the subtract does not have signed overflow.
21232175
// If both GEPs are nuw and the original sub is nuw, the new sub is also nuw.
2124-
if (GEP2) {
2125-
GEPNoWrapFlags GEP2NW = GEP2->getNoWrapFlags();
2126-
Value *Offset = EmitGEPOffset(GEP2, RewriteGEPs);
2127-
Result = Builder.CreateSub(Result, Offset, "gepdiff",
2128-
IsNUW && GEP1NW.hasNoUnsignedWrap() &&
2129-
GEP2NW.hasNoUnsignedWrap(),
2130-
GEP1NW.isInBounds() && GEP2NW.isInBounds());
2131-
}
2132-
2133-
// If we have p - gep(p, ...) then we have to negate the result.
2134-
if (Swapped)
2135-
Result = Builder.CreateNeg(Result, "diff.neg");
2176+
if (!match(Offset2, m_Zero())) {
2177+
Result =
2178+
Builder.CreateSub(Result, Offset2, "gepdiff",
2179+
IsNUW && Base.LHSNW.hasNoUnsignedWrap() &&
2180+
Base.RHSNW.hasNoUnsignedWrap(),
2181+
Base.LHSNW.isInBounds() && Base.RHSNW.isInBounds());
2182+
}
21362183

21372184
return Builder.CreateIntCast(Result, Ty, true);
21382185
}

llvm/test/Transforms/InstCombine/icmp.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,7 @@ define <2 x i1> @test23vec(<2 x i32> %x) {
506506
; unsigned overflow does not happen during offset computation
507507
define i1 @test24_neg_offs(ptr %p, i64 %offs) {
508508
; CHECK-LABEL: @test24_neg_offs(
509-
; CHECK-NEXT: [[P1_IDX_NEG:%.*]] = mul i64 [[OFFS:%.*]], -4
510-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[P1_IDX_NEG]], 8
509+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OFFS:%.*]], -2
511510
; CHECK-NEXT: ret i1 [[CMP]]
512511
;
513512
%p1 = getelementptr inbounds i32, ptr %p, i64 %offs

llvm/test/Transforms/InstCombine/sub-gep.ll

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ define i32 @test_inbounds_nuw_trunc(ptr %base, i64 %idx) {
8080

8181
define i64 @test_inbounds_nuw_swapped(ptr %base, i64 %idx) {
8282
; CHECK-LABEL: @test_inbounds_nuw_swapped(
83-
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
83+
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
8484
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
8585
;
8686
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
@@ -104,7 +104,7 @@ define i64 @test_inbounds1_nuw_swapped(ptr %base, i64 %idx) {
104104

105105
define i64 @test_inbounds2_nuw_swapped(ptr %base, i64 %idx) {
106106
; CHECK-LABEL: @test_inbounds2_nuw_swapped(
107-
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
107+
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
108108
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
109109
;
110110
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
@@ -279,8 +279,8 @@ define i16 @test24_as1(ptr addrspace(1) %P, i16 %A) {
279279

280280
define i64 @test24a(ptr %P, i64 %A){
281281
; CHECK-LABEL: @test24a(
282-
; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i64 0, [[A:%.*]]
283-
; CHECK-NEXT: ret i64 [[DIFF_NEG]]
282+
; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i64 0, [[A:%.*]]
283+
; CHECK-NEXT: ret i64 [[GEPDIFF]]
284284
;
285285
%B = getelementptr inbounds i8, ptr %P, i64 %A
286286
%C = ptrtoint ptr %B to i64
@@ -291,8 +291,8 @@ define i64 @test24a(ptr %P, i64 %A){
291291

292292
define i16 @test24a_as1(ptr addrspace(1) %P, i16 %A) {
293293
; CHECK-LABEL: @test24a_as1(
294-
; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i16 0, [[A:%.*]]
295-
; CHECK-NEXT: ret i16 [[DIFF_NEG]]
294+
; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i16 0, [[A:%.*]]
295+
; CHECK-NEXT: ret i16 [[GEPDIFF]]
296296
;
297297
%B = getelementptr inbounds i8, ptr addrspace(1) %P, i16 %A
298298
%C = ptrtoint ptr addrspace(1) %B to i16
@@ -860,3 +860,85 @@ _Z3fooPKc.exit:
860860
%tobool = icmp eq i64 %2, 0
861861
ret i1 %tobool
862862
}
863+
864+
define i64 @multiple_geps_one_chain(ptr %base, i64 %idx, i64 %idx2) {
865+
; CHECK-LABEL: @multiple_geps_one_chain(
866+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
867+
; CHECK-NEXT: [[D:%.*]] = shl i64 [[P2_IDX1]], 2
868+
; CHECK-NEXT: ret i64 [[D]]
869+
;
870+
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
871+
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
872+
%i1 = ptrtoint ptr %base to i64
873+
%i2 = ptrtoint ptr %p3 to i64
874+
%d = sub i64 %i2, %i1
875+
ret i64 %d
876+
}
877+
878+
define i64 @multiple_geps_one_chain_commuted(ptr %base, i64 %idx, i64 %idx2) {
879+
; CHECK-LABEL: @multiple_geps_one_chain_commuted(
880+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
881+
; CHECK-NEXT: [[DOTNEG:%.*]] = mul i64 [[P2_IDX1]], -4
882+
; CHECK-NEXT: ret i64 [[DOTNEG]]
883+
;
884+
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
885+
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
886+
%i1 = ptrtoint ptr %base to i64
887+
%i2 = ptrtoint ptr %p3 to i64
888+
%d = sub i64 %i1, %i2
889+
ret i64 %d
890+
}
891+
892+
define i64 @multiple_geps_two_chains(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
893+
; CHECK-LABEL: @multiple_geps_two_chains(
894+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
895+
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
896+
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
897+
; CHECK-NEXT: ret i64 [[GEPDIFF]]
898+
;
899+
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
900+
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
901+
%p4 = getelementptr inbounds i32, ptr %base, i64 %idx3
902+
%i1 = ptrtoint ptr %p4 to i64
903+
%i2 = ptrtoint ptr %p3 to i64
904+
%d = sub i64 %i2, %i1
905+
ret i64 %d
906+
}
907+
908+
define i64 @multiple_geps_two_chains_commuted(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
909+
; CHECK-LABEL: @multiple_geps_two_chains_commuted(
910+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
911+
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[IDX3:%.*]], [[P2_IDX1]]
912+
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
913+
; CHECK-NEXT: ret i64 [[GEPDIFF]]
914+
;
915+
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
916+
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
917+
%p4 = getelementptr inbounds i32, ptr %base, i64 %idx3
918+
%i1 = ptrtoint ptr %p4 to i64
919+
%i2 = ptrtoint ptr %p3 to i64
920+
%d = sub i64 %i1, %i2
921+
ret i64 %d
922+
}
923+
924+
declare void @use(ptr)
925+
926+
define i64 @multiple_geps_two_chains_gep_base(ptr %base, i64 %base.idx, i64 %idx, i64 %idx2, i64 %idx3) {
927+
; CHECK-LABEL: @multiple_geps_two_chains_gep_base(
928+
; CHECK-NEXT: [[GEP_BASE:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[BASE_IDX:%.*]]
929+
; CHECK-NEXT: call void @use(ptr [[GEP_BASE]])
930+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
931+
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
932+
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
933+
; CHECK-NEXT: ret i64 [[GEPDIFF]]
934+
;
935+
%gep.base = getelementptr inbounds i32, ptr %base, i64 %base.idx
936+
call void @use(ptr %gep.base)
937+
%p2 = getelementptr inbounds i32, ptr %gep.base, i64 %idx
938+
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
939+
%p4 = getelementptr inbounds i32, ptr %gep.base, i64 %idx3
940+
%i1 = ptrtoint ptr %p4 to i64
941+
%i2 = ptrtoint ptr %p3 to i64
942+
%d = sub i64 %i2, %i1
943+
ret i64 %d
944+
}

0 commit comments

Comments
 (0)