Skip to content

Commit 2f8e37d

Browse files
authored
[SROA] Unfold gep of index select (#80983)
SROA currently supports converting a gep of select into select of gep if the select is in the pointer operand. This patch expands support to selects in an index operand. This is intended to address the regression reported in #68882 (comment).
1 parent b9079ba commit 2f8e37d

File tree

2 files changed

+103
-28
lines changed

2 files changed

+103
-28
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3942,30 +3942,62 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
39423942
return false;
39433943
}
39443944

3945-
// Fold gep (select cond, ptr1, ptr2) => select cond, gep(ptr1), gep(ptr2)
3945+
// Fold gep (select cond, ptr1, ptr2), idx
3946+
// => select cond, gep(ptr1, idx), gep(ptr2, idx)
3947+
// and gep ptr, (select cond, idx1, idx2)
3948+
// => select cond, gep(ptr, idx1), gep(ptr, idx2)
39463949
bool foldGEPSelect(GetElementPtrInst &GEPI) {
3947-
if (!GEPI.hasAllConstantIndices())
3948-
return false;
3950+
// Check whether the GEP has exactly one select operand and all indices
3951+
// will become constant after the transform.
3952+
SelectInst *Sel = dyn_cast<SelectInst>(GEPI.getPointerOperand());
3953+
for (Value *Op : GEPI.indices()) {
3954+
if (auto *SI = dyn_cast<SelectInst>(Op)) {
3955+
if (Sel)
3956+
return false;
3957+
3958+
Sel = SI;
3959+
if (!isa<ConstantInt>(Sel->getTrueValue()) ||
3960+
!isa<ConstantInt>(Sel->getFalseValue()))
3961+
return false;
3962+
continue;
3963+
}
39493964

3950-
SelectInst *Sel = cast<SelectInst>(GEPI.getPointerOperand());
3965+
if (!isa<ConstantInt>(Op))
3966+
return false;
3967+
}
3968+
3969+
if (!Sel)
3970+
return false;
39513971

39523972
LLVM_DEBUG(dbgs() << " Rewriting gep(select) -> select(gep):"
39533973
<< "\n original: " << *Sel
39543974
<< "\n " << GEPI);
39553975

3976+
auto GetNewOps = [&](Value *SelOp) {
3977+
SmallVector<Value *> NewOps;
3978+
for (Value *Op : GEPI.operands())
3979+
if (Op == Sel)
3980+
NewOps.push_back(SelOp);
3981+
else
3982+
NewOps.push_back(Op);
3983+
return NewOps;
3984+
};
3985+
3986+
Value *True = Sel->getTrueValue();
3987+
Value *False = Sel->getFalseValue();
3988+
SmallVector<Value *> TrueOps = GetNewOps(True);
3989+
SmallVector<Value *> FalseOps = GetNewOps(False);
3990+
39563991
IRB.SetInsertPoint(&GEPI);
3957-
SmallVector<Value *, 4> Index(GEPI.indices());
39583992
bool IsInBounds = GEPI.isInBounds();
39593993

39603994
Type *Ty = GEPI.getSourceElementType();
3961-
Value *True = Sel->getTrueValue();
3962-
Value *NTrue = IRB.CreateGEP(Ty, True, Index, True->getName() + ".sroa.gep",
3963-
IsInBounds);
3964-
3965-
Value *False = Sel->getFalseValue();
3995+
Value *NTrue = IRB.CreateGEP(Ty, TrueOps[0], ArrayRef(TrueOps).drop_front(),
3996+
True->getName() + ".sroa.gep", IsInBounds);
39663997

3967-
Value *NFalse = IRB.CreateGEP(Ty, False, Index,
3968-
False->getName() + ".sroa.gep", IsInBounds);
3998+
Value *NFalse =
3999+
IRB.CreateGEP(Ty, FalseOps[0], ArrayRef(FalseOps).drop_front(),
4000+
False->getName() + ".sroa.gep", IsInBounds);
39694001

39704002
Value *NSel = IRB.CreateSelect(Sel->getCondition(), NTrue, NFalse,
39714003
Sel->getName() + ".sroa.sel");
@@ -4039,8 +4071,7 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
40394071
}
40404072

40414073
bool visitGetElementPtrInst(GetElementPtrInst &GEPI) {
4042-
if (isa<SelectInst>(GEPI.getPointerOperand()) &&
4043-
foldGEPSelect(GEPI))
4074+
if (foldGEPSelect(GEPI))
40444075
return true;
40454076

40464077
if (isa<PHINode>(GEPI.getPointerOperand()) &&

llvm/test/Transforms/SROA/select-gep.ll

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,24 @@ bb:
155155
ret i32 %load
156156
}
157157

158-
158+
; Test gep of index select unfolding on an alloca that is splittable, but not
159+
; promotable. The allocas here will be optimized away by subsequent passes.
159160
define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
160161
; CHECK-LABEL: @test_select_idx_memcpy(
161-
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
162-
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr [[ALLOCA]], ptr [[P:%.*]], i64 160, i1 false)
162+
; CHECK-NEXT: [[ALLOCA_SROA_0:%.*]] = alloca [4 x i8], align 8
163+
; CHECK-NEXT: [[ALLOCA_SROA_2:%.*]] = alloca [20 x i8], align 4
164+
; CHECK-NEXT: [[ALLOCA_SROA_22:%.*]] = alloca [4 x i8], align 8
165+
; CHECK-NEXT: [[ALLOCA_SROA_3:%.*]] = alloca [132 x i8], align 4
166+
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[ALLOCA_SROA_0]], ptr align 1 [[P:%.*]], i64 4, i1 false)
167+
; CHECK-NEXT: [[ALLOCA_SROA_2_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 4
168+
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[ALLOCA_SROA_2]], ptr align 1 [[ALLOCA_SROA_2_0_P_SROA_IDX]], i64 20, i1 false)
169+
; CHECK-NEXT: [[ALLOCA_SROA_22_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 24
170+
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[ALLOCA_SROA_22]], ptr align 1 [[ALLOCA_SROA_22_0_P_SROA_IDX]], i64 4, i1 false)
171+
; CHECK-NEXT: [[ALLOCA_SROA_3_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 28
172+
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[ALLOCA_SROA_3]], ptr align 1 [[ALLOCA_SROA_3_0_P_SROA_IDX]], i64 132, i1 false)
163173
; CHECK-NEXT: [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
164-
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
165-
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP]], align 4
174+
; CHECK-NEXT: [[IDX_SROA_SEL:%.*]] = select i1 [[C]], ptr [[ALLOCA_SROA_22]], ptr [[ALLOCA_SROA_0]]
175+
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[IDX_SROA_SEL]], align 4
166176
; CHECK-NEXT: ret i32 [[RES]]
167177
;
168178
%alloca = alloca [20 x i64], align 8
@@ -173,16 +183,13 @@ define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
173183
ret i32 %res
174184
}
175185

186+
; Test gep of index select unfolding on an alloca that is splittable and
187+
; promotable.
176188
define i32 @test_select_idx_mem2reg(i1 %c) {
177189
; CHECK-LABEL: @test_select_idx_mem2reg(
178-
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
179-
; CHECK-NEXT: store i32 1, ptr [[ALLOCA]], align 4
180-
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
181-
; CHECK-NEXT: store i32 2, ptr [[GEP1]], align 4
182190
; CHECK-NEXT: [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
183-
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
184-
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP2]], align 4
185-
; CHECK-NEXT: ret i32 [[RES]]
191+
; CHECK-NEXT: [[RES_SROA_SPECULATED:%.*]] = select i1 [[C]], i32 2, i32 1
192+
; CHECK-NEXT: ret i32 [[RES_SROA_SPECULATED]]
186193
;
187194
%alloca = alloca [20 x i64], align 8
188195
store i32 1, ptr %alloca
@@ -194,6 +201,9 @@ define i32 @test_select_idx_mem2reg(i1 %c) {
194201
ret i32 %res
195202
}
196203

204+
; Test gep of index select unfolding on an alloca that escaped, and as such
205+
; is not splittable or promotable.
206+
; FIXME: Ideally, no transform would take place in this case.
197207
define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
198208
; CHECK-LABEL: @test_select_idx_escaped(
199209
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
@@ -202,8 +212,10 @@ define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
202212
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
203213
; CHECK-NEXT: store i32 2, ptr [[GEP1]], align 4
204214
; CHECK-NEXT: [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
205-
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
206-
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP2]], align 4
215+
; CHECK-NEXT: [[DOTSROA_GEP:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
216+
; CHECK-NEXT: [[DOTSROA_GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 0
217+
; CHECK-NEXT: [[IDX_SROA_SEL:%.*]] = select i1 [[C]], ptr [[DOTSROA_GEP]], ptr [[DOTSROA_GEP1]]
218+
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[IDX_SROA_SEL]], align 4
207219
; CHECK-NEXT: ret i32 [[RES]]
208220
;
209221
%alloca = alloca [20 x i64], align 8
@@ -217,6 +229,38 @@ define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
217229
ret i32 %res
218230
}
219231

232+
; FIXME: Should we allow recursive select unfolding if all the leaves are
233+
; constants?
234+
define i32 @test_select_idx_nested(i1 %c, i1 %c2) {
235+
; CHECK-LABEL: @test_select_idx_nested(
236+
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8
237+
; CHECK-NEXT: store i32 1, ptr [[ALLOCA]], align 4
238+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 8
239+
; CHECK-NEXT: store i32 2, ptr [[GEP1]], align 4
240+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
241+
; CHECK-NEXT: store i32 3, ptr [[GEP2]], align 4
242+
; CHECK-NEXT: [[IDX1:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
243+
; CHECK-NEXT: [[IDX2:%.*]] = select i1 [[C2:%.*]], i64 [[IDX1]], i64 8
244+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX2]]
245+
; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP3]], align 4
246+
; CHECK-NEXT: ret i32 [[RES]]
247+
;
248+
%alloca = alloca [20 x i64], align 8
249+
store i32 1, ptr %alloca
250+
%gep1 = getelementptr inbounds i8, ptr %alloca, i64 8
251+
store i32 2, ptr %gep1
252+
%gep2 = getelementptr inbounds i8, ptr %alloca, i64 24
253+
store i32 3, ptr %gep2
254+
%idx1 = select i1 %c, i64 24, i64 0
255+
%idx2 = select i1 %c2, i64 %idx1, i64 8
256+
%gep3 = getelementptr inbounds i8, ptr %alloca, i64 %idx2
257+
%res = load i32, ptr %gep3, align 4
258+
ret i32 %res
259+
}
260+
261+
; The following cases involve non-constant indices and should not be
262+
; transformed.
263+
220264
define i32 @test_select_idx_not_constant1(i1 %c, ptr %p, i64 %arg) {
221265
; CHECK-LABEL: @test_select_idx_not_constant1(
222266
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [20 x i64], align 8

0 commit comments

Comments
 (0)