Skip to content

[InstCombine] Support nested GEPs in OptimizePointerDifference #142958

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 92 additions & 45 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2068,71 +2068,118 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return nullptr;
}

/// Result of searching the GEP chains of two pointers for a shared base
/// pointer; produced by computeCommonBase().
///
/// When Ptr is null no usable common base was found. In that case the RHS
/// members may still hold whatever was collected before the search gave up, so
/// callers should test Ptr before using the other fields.
struct CommonBase {
/// Common base pointer. Stays null when no common base was found.
Value *Ptr = nullptr;
/// LHS GEPs until common base, in the order they are encountered when walking
/// from the original LHS value towards the base (outermost GEP first).
SmallVector<GEPOperator *> LHSGEPs;
/// RHS GEPs until common base, in the order they are encountered when walking
/// from the original RHS value towards the base (outermost GEP first).
SmallVector<GEPOperator *> RHSGEPs;
/// LHS GEP NoWrapFlags until common base. Starts as all() and is intersected
/// with the flags of every GEP recorded in LHSGEPs.
GEPNoWrapFlags LHSNW = GEPNoWrapFlags::all();
/// RHS GEP NoWrapFlags until common base. Starts as all() and is intersected
/// with the flags of every GEP recorded in RHSGEPs.
GEPNoWrapFlags RHSNW = GEPNoWrapFlags::all();
};

/// Find the nearest pointer shared by the GEP chains above \p LHS and \p RHS.
///
/// Both values are walked upwards through GEPOperator pointer operands. On
/// success the returned CommonBase has Ptr set to the first pointer on the RHS
/// chain that also appears on the LHS chain; LHSGEPs / RHSGEPs list the GEPs
/// stepped over on each side (outermost first) and LHSNW / RHSNW hold the
/// intersection of their no-wrap flags. On failure Ptr is left null.
static CommonBase computeCommonBase(Value *LHS, Value *RHS) {
  CommonBase Result;

  // Operands of differing types are not handled.
  if (LHS->getType() != RHS->getType())
    return Result;

  // Remember LHS itself plus every pointer reachable from it by peeling GEPs.
  SmallPtrSet<Value *, 16> LHSChain;
  for (Value *Cur = LHS;;) {
    LHSChain.insert(Cur);
    auto *CurGEP = dyn_cast<GEPOperator>(Cur);
    if (!CurGEP)
      break;
    Cur = CurGEP->getPointerOperand();
  }

  // Peel GEPs off RHS until it lands on a pointer that is on the LHS chain.
  while (!LHSChain.contains(RHS)) {
    auto *RHSGEP = dyn_cast<GEPOperator>(RHS);
    // Ran out of GEPs without meeting the LHS chain: there is no common base.
    if (!RHSGEP)
      return Result;

    Result.RHSGEPs.push_back(RHSGEP);
    Result.RHSNW &= RHSGEP->getNoWrapFlags();
    RHS = RHSGEP->getPointerOperand();
  }

  // The candidate base must still have the same type as the original LHS.
  if (LHS->getType() != RHS->getType())
    return Result;
  Result.Ptr = RHS;

  // Gather the GEPs between LHS and the base. The base was recorded while
  // walking the LHS chain, and every value before it on that chain was seen to
  // be a GEP, so the unconditional cast<> below holds.
  while (LHS != Result.Ptr) {
    auto *LHSGEP = cast<GEPOperator>(LHS);
    Result.LHSGEPs.push_back(LHSGEP);
    Result.LHSNW &= LHSGEP->getNoWrapFlags();
    LHS = LHSGEP->getPointerOperand();
  }

  return Result;
}

/// Optimize pointer differences into the same array into a size. Consider:
/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
Type *Ty, bool IsNUW) {
// If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
// this.
bool Swapped = false;
GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) {
std::swap(LHS, RHS);
Swapped = true;
}

// Require at least one GEP with a common base pointer on both sides.
if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
// (gep X, ...) - X
if (LHSGEP->getOperand(0)->stripPointerCasts() ==
RHS->stripPointerCasts()) {
GEP1 = LHSGEP;
} else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
// (gep X, ...) - (gep X, ...)
if (LHSGEP->getOperand(0)->stripPointerCasts() ==
RHSGEP->getOperand(0)->stripPointerCasts()) {
GEP1 = LHSGEP;
GEP2 = RHSGEP;
}
}
}

if (!GEP1)
CommonBase Base = computeCommonBase(LHS, RHS);
if (!Base.Ptr)
return nullptr;

// To avoid duplicating the offset arithmetic, rewrite the GEP to use the
// computed offset. This may erase the original GEP, so be sure to cache the
// nowrap flags before emitting the offset.
// computed offset.
// TODO: We should probably do this even if there is only one GEP.
bool RewriteGEPs = GEP2 != nullptr;
bool RewriteGEPs = !Base.LHSGEPs.empty() && !Base.RHSGEPs.empty();

Type *IdxTy = DL.getIndexType(Base.Ptr->getType());
auto EmitOffsetFromBase = [&](ArrayRef<GEPOperator *> GEPs) -> Value * {
Value *Sum = nullptr;
for (GEPOperator *GEP : reverse(GEPs)) {
Value *Offset = EmitGEPOffset(GEP, RewriteGEPs);
if (Sum)
Sum = Builder.CreateAdd(Sum, Offset);
else
Sum = Offset;
}
if (!Sum)
return Constant::getNullValue(IdxTy);
return Sum;
};

// Emit the offset of the GEP and an intptr_t.
GEPNoWrapFlags GEP1NW = GEP1->getNoWrapFlags();
Value *Result = EmitGEPOffset(GEP1, RewriteGEPs);
Value *Result = EmitOffsetFromBase(Base.LHSGEPs);
Value *Offset2 = EmitOffsetFromBase(Base.RHSGEPs);

// If this is a single inbounds GEP and the original sub was nuw,
// then the final multiplication is also nuw.
if (auto *I = dyn_cast<Instruction>(Result))
if (IsNUW && !GEP2 && !Swapped && GEP1NW.isInBounds() &&
if (IsNUW && match(Offset2, m_Zero()) && Base.LHSNW.isInBounds() &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check nusw instead? I don't know why Alive2 complains about this: https://alive2.llvm.org/ce/z/72ZFQa

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Based on the proof, I assume this is commenting on the isInBounds use below, not this one.)

I think it's easier to understand if you consider the (ptradd(p, a) - ptradd(p, b)) case. With nusw, if p is sitting in the middle of the address space, you could have a as a large positive value and b as a large negative one, with overflow if you subtract them. With inbounds, it is guaranteed that the distance between a and b cannot exceed half the address space.

I->getOpcode() == Instruction::Mul)
I->setHasNoUnsignedWrap();

// If we have a 2nd GEP of the same base pointer, subtract the offsets.
// If both GEPs are inbounds, then the subtract does not have signed overflow.
// If both GEPs are nuw and the original sub is nuw, the new sub is also nuw.
if (GEP2) {
GEPNoWrapFlags GEP2NW = GEP2->getNoWrapFlags();
Value *Offset = EmitGEPOffset(GEP2, RewriteGEPs);
Result = Builder.CreateSub(Result, Offset, "gepdiff",
IsNUW && GEP1NW.hasNoUnsignedWrap() &&
GEP2NW.hasNoUnsignedWrap(),
GEP1NW.isInBounds() && GEP2NW.isInBounds());
}

// If we have p - gep(p, ...) then we have to negate the result.
if (Swapped)
Result = Builder.CreateNeg(Result, "diff.neg");
if (!match(Offset2, m_Zero())) {
Result =
Builder.CreateSub(Result, Offset2, "gepdiff",
IsNUW && Base.LHSNW.hasNoUnsignedWrap() &&
Base.RHSNW.hasNoUnsignedWrap(),
Base.LHSNW.isInBounds() && Base.RHSNW.isInBounds());
}

return Builder.CreateIntCast(Result, Ty, true);
}
Expand Down
3 changes: 1 addition & 2 deletions llvm/test/Transforms/InstCombine/icmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,7 @@ define <2 x i1> @test23vec(<2 x i32> %x) {
; unsigned overflow does not happen during offset computation
define i1 @test24_neg_offs(ptr %p, i64 %offs) {
; CHECK-LABEL: @test24_neg_offs(
; CHECK-NEXT: [[P1_IDX_NEG:%.*]] = mul i64 [[OFFS:%.*]], -4
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[P1_IDX_NEG]], 8
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OFFS:%.*]], -2
; CHECK-NEXT: ret i1 [[CMP]]
;
%p1 = getelementptr inbounds i32, ptr %p, i64 %offs
Expand Down
94 changes: 88 additions & 6 deletions llvm/test/Transforms/InstCombine/sub-gep.ll
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ define i32 @test_inbounds_nuw_trunc(ptr %base, i64 %idx) {

define i64 @test_inbounds_nuw_swapped(ptr %base, i64 %idx) {
; CHECK-LABEL: @test_inbounds_nuw_swapped(
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
;
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
Expand All @@ -104,7 +104,7 @@ define i64 @test_inbounds1_nuw_swapped(ptr %base, i64 %idx) {

define i64 @test_inbounds2_nuw_swapped(ptr %base, i64 %idx) {
; CHECK-LABEL: @test_inbounds2_nuw_swapped(
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
;
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
Expand Down Expand Up @@ -279,8 +279,8 @@ define i16 @test24_as1(ptr addrspace(1) %P, i16 %A) {

define i64 @test24a(ptr %P, i64 %A){
; CHECK-LABEL: @test24a(
; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i64 0, [[A:%.*]]
; CHECK-NEXT: ret i64 [[DIFF_NEG]]
; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i64 0, [[A:%.*]]
; CHECK-NEXT: ret i64 [[GEPDIFF]]
;
%B = getelementptr inbounds i8, ptr %P, i64 %A
%C = ptrtoint ptr %B to i64
Expand All @@ -291,8 +291,8 @@ define i64 @test24a(ptr %P, i64 %A){

define i16 @test24a_as1(ptr addrspace(1) %P, i16 %A) {
; CHECK-LABEL: @test24a_as1(
; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i16 0, [[A:%.*]]
; CHECK-NEXT: ret i16 [[DIFF_NEG]]
; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i16 0, [[A:%.*]]
; CHECK-NEXT: ret i16 [[GEPDIFF]]
;
%B = getelementptr inbounds i8, ptr addrspace(1) %P, i16 %A
%C = ptrtoint ptr addrspace(1) %B to i16
Expand Down Expand Up @@ -860,3 +860,85 @@ _Z3fooPKc.exit:
%tobool = icmp eq i64 %2, 0
ret i1 %tobool
}

; ptrtoint(%p3) - ptrtoint(%base), where %p3 is reached from %base through a
; chain of two inbounds i32 GEPs. The expected output adds the two indices and
; shifts the sum left by 2; no GEP or ptrtoint remains.
define i64 @multiple_geps_one_chain(ptr %base, i64 %idx, i64 %idx2) {
; CHECK-LABEL: @multiple_geps_one_chain(
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
; CHECK-NEXT: [[D:%.*]] = shl i64 [[P2_IDX1]], 2
; CHECK-NEXT: ret i64 [[D]]
;
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
%i1 = ptrtoint ptr %base to i64
%i2 = ptrtoint ptr %p3 to i64
%d = sub i64 %i2, %i1
ret i64 %d
}

; Same two-GEP chain as @multiple_geps_one_chain, but with the subtraction
; operands swapped: ptrtoint(%base) - ptrtoint(%p3). The expected output adds
; the two indices and multiplies the sum by -4.
define i64 @multiple_geps_one_chain_commuted(ptr %base, i64 %idx, i64 %idx2) {
; CHECK-LABEL: @multiple_geps_one_chain_commuted(
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
; CHECK-NEXT: [[DOTNEG:%.*]] = mul i64 [[P2_IDX1]], -4
; CHECK-NEXT: ret i64 [[DOTNEG]]
;
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
%i1 = ptrtoint ptr %base to i64
%i2 = ptrtoint ptr %p3 to i64
%d = sub i64 %i1, %i2
ret i64 %d
}

; Both subtraction operands are GEP chains rooted at %base: %p3 via two GEPs
; and %p4 via one. For ptrtoint(%p3) - ptrtoint(%p4) the expected output is
; ((idx + idx2) - idx3) shifted left by 2.
define i64 @multiple_geps_two_chains(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
; CHECK-LABEL: @multiple_geps_two_chains(
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
; CHECK-NEXT: ret i64 [[GEPDIFF]]
;
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
%p4 = getelementptr inbounds i32, ptr %base, i64 %idx3
%i1 = ptrtoint ptr %p4 to i64
%i2 = ptrtoint ptr %p3 to i64
%d = sub i64 %i2, %i1
ret i64 %d
}

; Same pointers as @multiple_geps_two_chains with the subtraction operands
; swapped: ptrtoint(%p4) - ptrtoint(%p3). The expected output is
; (idx3 - (idx + idx2)) shifted left by 2.
define i64 @multiple_geps_two_chains_commuted(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
; CHECK-LABEL: @multiple_geps_two_chains_commuted(
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[IDX3:%.*]], [[P2_IDX1]]
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
; CHECK-NEXT: ret i64 [[GEPDIFF]]
;
%p2 = getelementptr inbounds i32, ptr %base, i64 %idx
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
%p4 = getelementptr inbounds i32, ptr %base, i64 %idx3
%i1 = ptrtoint ptr %p4 to i64
%i2 = ptrtoint ptr %p3 to i64
%d = sub i64 %i1, %i2
ret i64 %d
}

declare void @use(ptr)

; Two GEP chains whose shared base %gep.base is itself a GEP and has another
; use (the call to @use). The expected output keeps %gep.base and the call,
; and computes the difference only from the indices applied on top of
; %gep.base: ((idx + idx2) - idx3) shifted left by 2. %base.idx does not
; appear in the difference.
define i64 @multiple_geps_two_chains_gep_base(ptr %base, i64 %base.idx, i64 %idx, i64 %idx2, i64 %idx3) {
; CHECK-LABEL: @multiple_geps_two_chains_gep_base(
; CHECK-NEXT: [[GEP_BASE:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[BASE_IDX:%.*]]
; CHECK-NEXT: call void @use(ptr [[GEP_BASE]])
; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
; CHECK-NEXT: ret i64 [[GEPDIFF]]
;
%gep.base = getelementptr inbounds i32, ptr %base, i64 %base.idx
call void @use(ptr %gep.base)
%p2 = getelementptr inbounds i32, ptr %gep.base, i64 %idx
%p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
%p4 = getelementptr inbounds i32, ptr %gep.base, i64 %idx3
%i1 = ptrtoint ptr %p4 to i64
%i2 = ptrtoint ptr %p3 to i64
%d = sub i64 %i2, %i1
ret i64 %d
}
Loading