Skip to content

Commit 37d8ac6

Browse files
committed
Support OptimizePointerDifference for GEP chains
1 parent fc8f2c2 commit 37d8ac6

File tree

3 files changed

+121
-85
lines changed

3 files changed

+121
-85
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 93 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,71 +2068,119 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
20682068
return nullptr;
20692069
}
20702070

2071+
struct CommonBase {
2072+
/// Common base pointer.
2073+
Value *Ptr = nullptr;
2074+
/// LHS GEPs until common base.
2075+
SmallVector<GEPOperator *> LHSGEPs;
2076+
/// RHS GEPs until common base.
2077+
SmallVector<GEPOperator *> RHSGEPs;
2078+
/// LHS GEP NoWrapFlags until common base.
2079+
GEPNoWrapFlags LHSNW = GEPNoWrapFlags::all();
2080+
/// RHS GEP NoWrapFlags until common base.
2081+
GEPNoWrapFlags RHSNW = GEPNoWrapFlags::all();
2082+
};
2083+
2084+
static CommonBase computeCommonBase(Value *LHS, Value *RHS) {
2085+
CommonBase Base;
2086+
2087+
if (LHS->getType() != RHS->getType())
2088+
return Base;
2089+
2090+
// Collect all base pointers of LHS.
2091+
SmallPtrSet<Value *, 16> Ptrs;
2092+
Value *Ptr = LHS;
2093+
while (true) {
2094+
Ptrs.insert(Ptr);
2095+
if (auto *GEP = dyn_cast<GEPOperator>(Ptr))
2096+
Ptr = GEP->getPointerOperand();
2097+
else
2098+
break;
2099+
}
2100+
2101+
// Find common base and collect RHS GEPs.
2102+
while (true) {
2103+
if (Ptrs.contains(RHS)) {
2104+
if (LHS->getType() != RHS->getType())
2105+
return Base;
2106+
Base.Ptr = RHS;
2107+
break;
2108+
}
2109+
2110+
if (auto *GEP = dyn_cast<GEPOperator>(RHS)) {
2111+
Base.RHSGEPs.push_back(GEP);
2112+
Base.RHSNW &= GEP->getNoWrapFlags();
2113+
RHS = GEP->getPointerOperand();
2114+
} else {
2115+
// No common base.
2116+
return Base;
2117+
}
2118+
}
2119+
2120+
// Collect LHS GEPs.
2121+
while (true) {
2122+
if (LHS == Base.Ptr)
2123+
break;
2124+
2125+
auto *GEP = cast<GEPOperator>(LHS);
2126+
Base.LHSGEPs.push_back(GEP);
2127+
Base.LHSNW &= GEP->getNoWrapFlags();
2128+
LHS = GEP->getPointerOperand();
2129+
}
2130+
2131+
return Base;
2132+
}
2133+
20712134
/// Optimize pointer differences into the same array into a size. Consider:
20722135
/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer
20732136
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
20742137
Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
20752138
Type *Ty, bool IsNUW) {
2076-
// If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
2077-
// this.
2078-
bool Swapped = false;
2079-
GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
2080-
if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) {
2081-
std::swap(LHS, RHS);
2082-
Swapped = true;
2083-
}
2084-
2085-
// Require at least one GEP with a common base pointer on both sides.
2086-
if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
2087-
// (gep X, ...) - X
2088-
if (LHSGEP->getOperand(0)->stripPointerCasts() ==
2089-
RHS->stripPointerCasts()) {
2090-
GEP1 = LHSGEP;
2091-
} else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
2092-
// (gep X, ...) - (gep X, ...)
2093-
if (LHSGEP->getOperand(0)->stripPointerCasts() ==
2094-
RHSGEP->getOperand(0)->stripPointerCasts()) {
2095-
GEP1 = LHSGEP;
2096-
GEP2 = RHSGEP;
2097-
}
2098-
}
2099-
}
2100-
2101-
if (!GEP1)
2139+
CommonBase Base = computeCommonBase(LHS, RHS);
2140+
if (!Base.Ptr)
21022141
return nullptr;
21032142

2143+
21042144
// To avoid duplicating the offset arithmetic, rewrite the GEP to use the
2105-
// computed offset. This may erase the original GEP, so be sure to cache the
2106-
// nowrap flags before emitting the offset.
2145+
// computed offset.
21072146
// TODO: We should probably do this even if there is only one GEP.
2108-
bool RewriteGEPs = GEP2 != nullptr;
2147+
bool RewriteGEPs = !Base.LHSGEPs.empty() && !Base.RHSGEPs.empty();
2148+
2149+
Type *IdxTy = DL.getIndexType(Base.Ptr->getType());
2150+
auto EmitOffsetFromBase = [&](ArrayRef<GEPOperator *> GEPs) -> Value * {
2151+
Value *Sum = nullptr;
2152+
for (GEPOperator *GEP : reverse(GEPs)) {
2153+
Value *Offset = EmitGEPOffset(GEP, RewriteGEPs);
2154+
if (Sum)
2155+
Sum = Builder.CreateAdd(Sum, Offset);
2156+
else
2157+
Sum = Offset;
2158+
}
2159+
if (!Sum)
2160+
return Constant::getNullValue(IdxTy);
2161+
return Sum;
2162+
};
21092163

2110-
// Emit the offset of the GEP and an intptr_t.
2111-
GEPNoWrapFlags GEP1NW = GEP1->getNoWrapFlags();
2112-
Value *Result = EmitGEPOffset(GEP1, RewriteGEPs);
2164+
Value *Result = EmitOffsetFromBase(Base.LHSGEPs);
2165+
Value *Offset2 = EmitOffsetFromBase(Base.RHSGEPs);
21132166

21142167
// If this is a single inbounds GEP and the original sub was nuw,
21152168
// then the final multiplication is also nuw.
21162169
if (auto *I = dyn_cast<Instruction>(Result))
2117-
if (IsNUW && !GEP2 && !Swapped && GEP1NW.isInBounds() &&
2170+
if (IsNUW && match(Offset2, m_Zero()) && Base.LHSNW.isInBounds() &&
21182171
I->getOpcode() == Instruction::Mul)
21192172
I->setHasNoUnsignedWrap();
21202173

21212174
// If we have a 2nd GEP of the same base pointer, subtract the offsets.
21222175
// If both GEPs are inbounds, then the subtract does not have signed overflow.
21232176
// If both GEPs are nuw and the original sub is nuw, the new sub is also nuw.
2124-
if (GEP2) {
2125-
GEPNoWrapFlags GEP2NW = GEP2->getNoWrapFlags();
2126-
Value *Offset = EmitGEPOffset(GEP2, RewriteGEPs);
2127-
Result = Builder.CreateSub(Result, Offset, "gepdiff",
2128-
IsNUW && GEP1NW.hasNoUnsignedWrap() &&
2129-
GEP2NW.hasNoUnsignedWrap(),
2130-
GEP1NW.isInBounds() && GEP2NW.isInBounds());
2131-
}
2132-
2133-
// If we have p - gep(p, ...) then we have to negate the result.
2134-
if (Swapped)
2135-
Result = Builder.CreateNeg(Result, "diff.neg");
2177+
if (!match(Offset2, m_Zero())) {
2178+
Result =
2179+
Builder.CreateSub(Result, Offset2, "gepdiff",
2180+
IsNUW && Base.LHSNW.hasNoUnsignedWrap() &&
2181+
Base.RHSNW.hasNoUnsignedWrap(),
2182+
Base.LHSNW.isInBounds() && Base.RHSNW.isInBounds());
2183+
}
21362184

21372185
return Builder.CreateIntCast(Result, Ty, true);
21382186
}

llvm/test/Transforms/InstCombine/icmp.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,7 @@ define <2 x i1> @test23vec(<2 x i32> %x) {
506506
; unsigned overflow does not happen during offset computation
507507
define i1 @test24_neg_offs(ptr %p, i64 %offs) {
508508
; CHECK-LABEL: @test24_neg_offs(
509-
; CHECK-NEXT: [[P1_IDX_NEG:%.*]] = mul i64 [[OFFS:%.*]], -4
510-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[P1_IDX_NEG]], 8
509+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OFFS:%.*]], -2
511510
; CHECK-NEXT: ret i1 [[CMP]]
512511
;
513512
%p1 = getelementptr inbounds i32, ptr %p, i64 %offs

llvm/test/Transforms/InstCombine/sub-gep.ll

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ define i32 @test_inbounds_nuw_trunc(ptr %base, i64 %idx) {
8080

8181
define i64 @test_inbounds_nuw_swapped(ptr %base, i64 %idx) {
8282
; CHECK-LABEL: @test_inbounds_nuw_swapped(
83-
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
83+
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
8484
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
8585
;
8686
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
@@ -104,7 +104,7 @@ define i64 @test_inbounds1_nuw_swapped(ptr %base, i64 %idx) {
104104

105105
define i64 @test_inbounds2_nuw_swapped(ptr %base, i64 %idx) {
106106
; CHECK-LABEL: @test_inbounds2_nuw_swapped(
107-
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
107+
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
108108
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
109109
;
110110
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
@@ -279,8 +279,8 @@ define i16 @test24_as1(ptr addrspace(1) %P, i16 %A) {
279279

280280
define i64 @test24a(ptr %P, i64 %A){
281281
; CHECK-LABEL: @test24a(
282-
; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i64 0, [[A:%.*]]
283-
; CHECK-NEXT: ret i64 [[DIFF_NEG]]
282+
; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i64 0, [[A:%.*]]
283+
; CHECK-NEXT: ret i64 [[GEPDIFF]]
284284
;
285285
%B = getelementptr inbounds i8, ptr %P, i64 %A
286286
%C = ptrtoint ptr %B to i64
@@ -291,8 +291,8 @@ define i64 @test24a(ptr %P, i64 %A){
291291

292292
define i16 @test24a_as1(ptr addrspace(1) %P, i16 %A) {
293293
; CHECK-LABEL: @test24a_as1(
294-
; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i16 0, [[A:%.*]]
295-
; CHECK-NEXT: ret i16 [[DIFF_NEG]]
294+
; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i16 0, [[A:%.*]]
295+
; CHECK-NEXT: ret i16 [[GEPDIFF]]
296296
;
297297
%B = getelementptr inbounds i8, ptr addrspace(1) %P, i16 %A
298298
%C = ptrtoint ptr addrspace(1) %B to i16
@@ -863,11 +863,8 @@ _Z3fooPKc.exit:
863863

864864
define i64 @multiple_geps_one_chain(ptr %base, i64 %idx, i64 %idx2) {
865865
; CHECK-LABEL: @multiple_geps_one_chain(
866-
; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[IDX:%.*]]
867-
; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds i32, ptr [[P2]], i64 [[IDX2:%.*]]
868-
; CHECK-NEXT: [[I1:%.*]] = ptrtoint ptr [[BASE]] to i64
869-
; CHECK-NEXT: [[I2:%.*]] = ptrtoint ptr [[P3]] to i64
870-
; CHECK-NEXT: [[D:%.*]] = sub i64 [[I2]], [[I1]]
866+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
867+
; CHECK-NEXT: [[D:%.*]] = shl i64 [[P2_IDX1]], 2
871868
; CHECK-NEXT: ret i64 [[D]]
872869
;
873870
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
@@ -880,12 +877,9 @@ define i64 @multiple_geps_one_chain(ptr %base, i64 %idx, i64 %idx2) {
880877

881878
define i64 @multiple_geps_one_chain_commuted(ptr %base, i64 %idx, i64 %idx2) {
882879
; CHECK-LABEL: @multiple_geps_one_chain_commuted(
883-
; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[IDX:%.*]]
884-
; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds i32, ptr [[P2]], i64 [[IDX2:%.*]]
885-
; CHECK-NEXT: [[I1:%.*]] = ptrtoint ptr [[BASE]] to i64
886-
; CHECK-NEXT: [[I2:%.*]] = ptrtoint ptr [[P3]] to i64
887-
; CHECK-NEXT: [[D:%.*]] = sub i64 [[I1]], [[I2]]
888-
; CHECK-NEXT: ret i64 [[D]]
880+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
881+
; CHECK-NEXT: [[DOTNEG:%.*]] = mul i64 [[P2_IDX1]], -4
882+
; CHECK-NEXT: ret i64 [[DOTNEG]]
889883
;
890884
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
891885
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
@@ -897,13 +891,10 @@ define i64 @multiple_geps_one_chain_commuted(ptr %base, i64 %idx, i64 %idx2) {
897891

898892
define i64 @multiple_geps_two_chains(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
899893
; CHECK-LABEL: @multiple_geps_two_chains(
900-
; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[IDX:%.*]]
901-
; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds i32, ptr [[P2]], i64 [[IDX2:%.*]]
902-
; CHECK-NEXT: [[P4:%.*]] = getelementptr inbounds i32, ptr [[BASE]], i64 [[IDX3:%.*]]
903-
; CHECK-NEXT: [[I1:%.*]] = ptrtoint ptr [[P4]] to i64
904-
; CHECK-NEXT: [[I2:%.*]] = ptrtoint ptr [[P3]] to i64
905-
; CHECK-NEXT: [[D:%.*]] = sub i64 [[I2]], [[I1]]
906-
; CHECK-NEXT: ret i64 [[D]]
894+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
895+
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
896+
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
897+
; CHECK-NEXT: ret i64 [[GEPDIFF]]
907898
;
908899
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
909900
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
@@ -916,13 +907,10 @@ define i64 @multiple_geps_two_chains(ptr %base, i64 %idx, i64 %idx2, i64 %idx3)
916907

917908
define i64 @multiple_geps_two_chains_commuted(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
918909
; CHECK-LABEL: @multiple_geps_two_chains_commuted(
919-
; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[IDX:%.*]]
920-
; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds i32, ptr [[P2]], i64 [[IDX2:%.*]]
921-
; CHECK-NEXT: [[P4:%.*]] = getelementptr inbounds i32, ptr [[BASE]], i64 [[IDX3:%.*]]
922-
; CHECK-NEXT: [[I1:%.*]] = ptrtoint ptr [[P4]] to i64
923-
; CHECK-NEXT: [[I2:%.*]] = ptrtoint ptr [[P3]] to i64
924-
; CHECK-NEXT: [[D:%.*]] = sub i64 [[I1]], [[I2]]
925-
; CHECK-NEXT: ret i64 [[D]]
910+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
911+
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[IDX3:%.*]], [[P2_IDX1]]
912+
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
913+
; CHECK-NEXT: ret i64 [[GEPDIFF]]
926914
;
927915
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
928916
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
@@ -933,18 +921,19 @@ define i64 @multiple_geps_two_chains_commuted(ptr %base, i64 %idx, i64 %idx2, i6
933921
ret i64 %d
934922
}
935923

924+
declare void @use(ptr)
925+
936926
define i64 @multiple_geps_two_chains_gep_base(ptr %base, i64 %base.idx, i64 %idx, i64 %idx2, i64 %idx3) {
937927
; CHECK-LABEL: @multiple_geps_two_chains_gep_base(
938928
; CHECK-NEXT: [[GEP_BASE:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[BASE_IDX:%.*]]
939-
; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds i32, ptr [[GEP_BASE]], i64 [[IDX:%.*]]
940-
; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds i32, ptr [[P2]], i64 [[IDX2:%.*]]
941-
; CHECK-NEXT: [[P4:%.*]] = getelementptr inbounds i32, ptr [[GEP_BASE]], i64 [[IDX3:%.*]]
942-
; CHECK-NEXT: [[I1:%.*]] = ptrtoint ptr [[P4]] to i64
943-
; CHECK-NEXT: [[I2:%.*]] = ptrtoint ptr [[P3]] to i64
944-
; CHECK-NEXT: [[D:%.*]] = sub i64 [[I2]], [[I1]]
945-
; CHECK-NEXT: ret i64 [[D]]
929+
; CHECK-NEXT: call void @use(ptr [[GEP_BASE]])
930+
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
931+
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
932+
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
933+
; CHECK-NEXT: ret i64 [[GEPDIFF]]
946934
;
947935
%gep.base = getelementptr inbounds i32, ptr %base, i64 %base.idx
936+
call void @use(ptr %gep.base)
948937
%p2 = getelementptr inbounds i32, ptr %gep.base, i64 %idx
949938
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
950939
%p4 = getelementptr inbounds i32, ptr %gep.base, i64 %idx3

0 commit comments

Comments
 (0)