Skip to content

[SROA] Unfold gep of index select #80983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions llvm/lib/Transforms/Scalar/SROA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3937,30 +3937,62 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
return false;
}

// Fold gep (select cond, ptr1, ptr2) => select cond, gep(ptr1), gep(ptr2)
// Fold gep (select cond, ptr1, ptr2), idx
// => select cond, gep(ptr1, idx), gep(ptr2, idx)
// and gep ptr, (select cond, idx1, idx2)
// => select cond, gep(ptr, idx1), gep(ptr, idx2)
bool foldGEPSelect(GetElementPtrInst &GEPI) {
if (!GEPI.hasAllConstantIndices())
return false;
// Check whether the GEP has exactly one select operand and all indices
// will become constant after the transform.
SelectInst *Sel = dyn_cast<SelectInst>(GEPI.getPointerOperand());
for (Value *Op : GEPI.indices()) {
if (auto *SI = dyn_cast<SelectInst>(Op)) {
if (Sel)
return false;

Sel = SI;
if (!isa<ConstantInt>(Sel->getTrueValue()) ||
!isa<ConstantInt>(Sel->getFalseValue()))
return false;
continue;
}

SelectInst *Sel = cast<SelectInst>(GEPI.getPointerOperand());
if (!isa<ConstantInt>(Op))
return false;
}

if (!Sel)
return false;

LLVM_DEBUG(dbgs() << " Rewriting gep(select) -> select(gep):"
<< "\n original: " << *Sel
<< "\n " << GEPI);

auto GetNewOps = [&](Value *SelOp) {
SmallVector<Value *> NewOps;
for (Value *Op : GEPI.operands())
if (Op == Sel)
NewOps.push_back(SelOp);
else
NewOps.push_back(Op);
return NewOps;
};

Value *True = Sel->getTrueValue();
Value *False = Sel->getFalseValue();
SmallVector<Value *> TrueOps = GetNewOps(True);
SmallVector<Value *> FalseOps = GetNewOps(False);

IRB.SetInsertPoint(&GEPI);
SmallVector<Value *, 4> Index(GEPI.indices());
bool IsInBounds = GEPI.isInBounds();

Type *Ty = GEPI.getSourceElementType();
Value *True = Sel->getTrueValue();
Value *NTrue = IRB.CreateGEP(Ty, True, Index, True->getName() + ".sroa.gep",
IsInBounds);

Value *False = Sel->getFalseValue();
Value *NTrue = IRB.CreateGEP(Ty, TrueOps[0], ArrayRef(TrueOps).drop_front(),
True->getName() + ".sroa.gep", IsInBounds);

Value *NFalse = IRB.CreateGEP(Ty, False, Index,
False->getName() + ".sroa.gep", IsInBounds);
Value *NFalse =
IRB.CreateGEP(Ty, FalseOps[0], ArrayRef(FalseOps).drop_front(),
False->getName() + ".sroa.gep", IsInBounds);

Value *NSel = IRB.CreateSelect(Sel->getCondition(), NTrue, NFalse,
Sel->getName() + ".sroa.sel");
Expand Down Expand Up @@ -4034,8 +4066,7 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
}

bool visitGetElementPtrInst(GetElementPtrInst &GEPI) {
if (isa<SelectInst>(GEPI.getPointerOperand()) &&
foldGEPSelect(GEPI))
if (foldGEPSelect(GEPI))
return true;

if (isa<PHINode>(GEPI.getPointerOperand()) &&
Expand Down
72 changes: 58 additions & 14 deletions llvm/test/Transforms/SROA/select-gep.ll
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,24 @@ bb:
ret i32 %load
}


; Test gep of index select unfolding on an alloca that is splittable, but not
; promotable. The allocas here will be optimized away by subsequent passes.
define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
; CHECK-LABEL: @test_select_idx_memcpy(
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr [[ALLOCA]], ptr [[P:%.*]], i64 160, i1 false)
; CHECK-NEXT: [[ALLOCA_SROA_0:%.*]] = alloca [4 x i8], align 8
; CHECK-NEXT: [[ALLOCA_SROA_2:%.*]] = alloca [20 x i8], align 4
; CHECK-NEXT: [[ALLOCA_SROA_22:%.*]] = alloca [4 x i8], align 8
; CHECK-NEXT: [[ALLOCA_SROA_3:%.*]] = alloca [132 x i8], align 4
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[ALLOCA_SROA_0]], ptr align 1 [[P:%.*]], i64 4, i1 false)
; CHECK-NEXT: [[ALLOCA_SROA_2_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 4
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[ALLOCA_SROA_2]], ptr align 1 [[ALLOCA_SROA_2_0_P_SROA_IDX]], i64 20, i1 false)
; CHECK-NEXT: [[ALLOCA_SROA_22_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 24
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[ALLOCA_SROA_22]], ptr align 1 [[ALLOCA_SROA_22_0_P_SROA_IDX]], i64 4, i1 false)
; CHECK-NEXT: [[ALLOCA_SROA_3_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 28
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[ALLOCA_SROA_3]], ptr align 1 [[ALLOCA_SROA_3_0_P_SROA_IDX]], i64 132, i1 false)
; CHECK-NEXT: [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP]], align 4
; CHECK-NEXT: [[IDX_SROA_SEL:%.*]] = select i1 [[C]], ptr [[ALLOCA_SROA_22]], ptr [[ALLOCA_SROA_0]]
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[IDX_SROA_SEL]], align 4
; CHECK-NEXT: ret i32 [[RES]]
;
%alloca = alloca [20 x i64], align 8
Expand All @@ -173,16 +183,13 @@ define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
ret i32 %res
}

; Test gep of index select unfolding on an alloca that is splittable and
; promotable.
define i32 @test_select_idx_mem2reg(i1 %c) {
; CHECK-LABEL: @test_select_idx_mem2reg(
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
; CHECK-NEXT: store i32 1, ptr [[ALLOCA]], align 4
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
; CHECK-NEXT: store i32 2, ptr [[GEP1]], align 4
; CHECK-NEXT: [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP2]], align 4
; CHECK-NEXT: ret i32 [[RES]]
; CHECK-NEXT: [[RES_SROA_SPECULATED:%.*]] = select i1 [[C]], i32 2, i32 1
; CHECK-NEXT: ret i32 [[RES_SROA_SPECULATED]]
;
%alloca = alloca [20 x i64], align 8
store i32 1, ptr %alloca
Expand All @@ -194,6 +201,9 @@ define i32 @test_select_idx_mem2reg(i1 %c) {
ret i32 %res
}

; Test gep of index select unfolding on an alloca that escaped, and as such
; is not splittable or promotable.
; FIXME: Ideally, no transform would take place in this case.
define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
; CHECK-LABEL: @test_select_idx_escaped(
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
Expand All @@ -202,8 +212,10 @@ define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
; CHECK-NEXT: store i32 2, ptr [[GEP1]], align 4
; CHECK-NEXT: [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP2]], align 4
; CHECK-NEXT: [[DOTSROA_GEP:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
; CHECK-NEXT: [[DOTSROA_GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 0
; CHECK-NEXT: [[IDX_SROA_SEL:%.*]] = select i1 [[C]], ptr [[DOTSROA_GEP]], ptr [[DOTSROA_GEP1]]
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[IDX_SROA_SEL]], align 4
; CHECK-NEXT: ret i32 [[RES]]
;
%alloca = alloca [20 x i64], align 8
Expand All @@ -217,6 +229,38 @@ define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
ret i32 %res
}

; FIXME: Should we allow recursive select unfolding if all the leaves are
; constants?
define i32 @test_select_idx_nested(i1 %c, i1 %c2) {
; CHECK-LABEL: @test_select_idx_nested(
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
; CHECK-NEXT: store i32 1, ptr [[ALLOCA]], align 4
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 8
; CHECK-NEXT: store i32 2, ptr [[GEP1]], align 4
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
; CHECK-NEXT: store i32 3, ptr [[GEP2]], align 4
; CHECK-NEXT: [[IDX1:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
; CHECK-NEXT: [[IDX2:%.*]] = select i1 [[C2:%.*]], i64 [[IDX1]], i64 8
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX2]]
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP3]], align 4
; CHECK-NEXT: ret i32 [[RES]]
;
%alloca = alloca [20 x i64], align 8
store i32 1, ptr %alloca
%gep1 = getelementptr inbounds i8, ptr %alloca, i64 8
store i32 2, ptr %gep1
%gep2 = getelementptr inbounds i8, ptr %alloca, i64 24
store i32 3, ptr %gep2
%idx1 = select i1 %c, i64 24, i64 0
%idx2 = select i1 %c2, i64 %idx1, i64 8
%gep3 = getelementptr inbounds i8, ptr %alloca, i64 %idx2
%res = load i32, ptr %gep3, align 4
ret i32 %res
}

; The following cases involve non-constant indices and should not be
; transformed.

define i32 @test_select_idx_not_constant1(i1 %c, ptr %p, i64 %arg) {
; CHECK-LABEL: @test_select_idx_not_constant1(
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
Expand Down