@@ -3956,11 +3956,11 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
3956
3956
return false ;
3957
3957
}
3958
3958
3959
- // Fold gep (select cond, ptr1, ptr2), idx
3959
+ // Unfold gep (select cond, ptr1, ptr2), idx
3960
3960
// => select cond, gep(ptr1, idx), gep(ptr2, idx)
3961
3961
// and gep ptr, (select cond, idx1, idx2)
3962
3962
// => select cond, gep(ptr, idx1), gep(ptr, idx2)
3963
- bool foldGEPSelect (GetElementPtrInst &GEPI) {
3963
+ bool unfoldGEPSelect (GetElementPtrInst &GEPI) {
3964
3964
// Check whether the GEP has exactly one select operand and all indices
3965
3965
// will become constant after the transform.
3966
3966
SelectInst *Sel = dyn_cast<SelectInst>(GEPI.getPointerOperand ());
@@ -4029,67 +4029,104 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
4029
4029
return true ;
4030
4030
}
4031
4031
4032
- // Fold gep (phi ptr1, ptr2) => phi gep(ptr1), gep(ptr2)
4033
- bool foldGEPPhi (GetElementPtrInst &GEPI) {
4034
- if (!GEPI.hasAllConstantIndices ())
4035
- return false ;
4032
+ // Unfold gep (phi ptr1, ptr2), idx
4033
+ // => phi ((gep ptr1, idx), (gep ptr2, idx))
4034
+ // and gep ptr, (phi idx1, idx2)
4035
+ // => phi ((gep ptr, idx1), (gep ptr, idx2))
4036
+ bool unfoldGEPPhi (GetElementPtrInst &GEPI) {
4037
+ // To prevent infinitely expanding recursive phis, bail if the GEP pointer
4038
+ // operand (looking through the phi if it is the phi we want to unfold) is
4039
+ // an instruction besides an alloca.
4040
+ PHINode *Phi = dyn_cast<PHINode>(GEPI.getPointerOperand ());
4041
+ auto IsInvalidPointerOperand = [](Value *V) {
4042
+ return isa<Instruction>(V) && !isa<AllocaInst>(V);
4043
+ };
4044
+ if (Phi) {
4045
+ if (any_of (Phi->operands (), IsInvalidPointerOperand))
4046
+ return false ;
4047
+ } else {
4048
+ if (IsInvalidPointerOperand (GEPI.getPointerOperand ()))
4049
+ return false ;
4050
+ }
4051
+ // Check whether the GEP has exactly one phi operand (including the pointer
4052
+ // operand) and all indices will become constant after the transform.
4053
+ for (Value *Op : GEPI.indices ()) {
4054
+ if (auto *SI = dyn_cast<PHINode>(Op)) {
4055
+ if (Phi)
4056
+ return false ;
4057
+
4058
+ Phi = SI;
4059
+ if (!all_of (Phi->incoming_values (),
4060
+ [](Value *V) { return isa<ConstantInt>(V); }))
4061
+ return false ;
4062
+ continue ;
4063
+ }
4036
4064
4037
- PHINode *PHI = cast<PHINode>(GEPI.getPointerOperand ());
4038
- if (GEPI.getParent () != PHI->getParent () ||
4039
- llvm::any_of (PHI->incoming_values (), [](Value *In) {
4040
- Instruction *I = dyn_cast<Instruction>(In);
4041
- return !I || isa<GetElementPtrInst>(I) || isa<PHINode>(I) ||
4042
- succ_empty (I->getParent ()) ||
4043
- !I->getParent ()->isLegalToHoistInto ();
4044
- }))
4065
+ if (!isa<ConstantInt>(Op))
4066
+ return false ;
4067
+ }
4068
+
4069
+ if (!Phi)
4045
4070
return false ;
4046
4071
4047
4072
LLVM_DEBUG (dbgs () << " Rewriting gep(phi) -> phi(gep):\n " ;
4048
- dbgs () << " original: " << *PHI << " \n " ;
4073
+ dbgs () << " original: " << *Phi << " \n " ;
4049
4074
dbgs () << " " << GEPI << " \n " ;);
4050
4075
4051
- SmallVector<Value *, 4 > Index (GEPI.indices ());
4076
+ auto GetNewOps = [&](Value *PhiOp) {
4077
+ SmallVector<Value *> NewOps;
4078
+ for (Value *Op : GEPI.operands ())
4079
+ if (Op == Phi)
4080
+ NewOps.push_back (PhiOp);
4081
+ else
4082
+ NewOps.push_back (Op);
4083
+ return NewOps;
4084
+ };
4085
+
4086
+ IRB.SetInsertPoint (Phi);
4087
+ PHINode *NewPhi = IRB.CreatePHI (GEPI.getType (), Phi->getNumIncomingValues (),
4088
+ Phi->getName () + " .sroa.phi" );
4089
+
4052
4090
bool IsInBounds = GEPI.isInBounds ();
4053
- IRB.SetInsertPoint (GEPI.getParent (), GEPI.getParent ()->getFirstNonPHIIt ());
4054
- PHINode *NewPN = IRB.CreatePHI (GEPI.getType (), PHI->getNumIncomingValues (),
4055
- PHI->getName () + " .sroa.phi" );
4056
- for (unsigned I = 0 , E = PHI->getNumIncomingValues (); I != E; ++I) {
4057
- BasicBlock *B = PHI->getIncomingBlock (I);
4058
- Value *NewVal = nullptr ;
4059
- int Idx = NewPN->getBasicBlockIndex (B);
4060
- if (Idx >= 0 ) {
4061
- NewVal = NewPN->getIncomingValue (Idx);
4091
+ Type *SourceTy = GEPI.getSourceElementType ();
4092
+ // We only handle arguments, constants, and static allocas here, so we can
4093
+ // insert GEPs at the end of the entry block.
4094
+ IRB.SetInsertPoint (GEPI.getFunction ()->getEntryBlock ().getTerminator ());
4095
+ for (unsigned I = 0 , E = Phi->getNumIncomingValues (); I != E; ++I) {
4096
+ Value *Op = Phi->getIncomingValue (I);
4097
+ BasicBlock *BB = Phi->getIncomingBlock (I);
4098
+ Value *NewGEP;
4099
+ if (int NI = NewPhi->getBasicBlockIndex (BB); NI >= 0 ) {
4100
+ NewGEP = NewPhi->getIncomingValue (NI);
4062
4101
} else {
4063
- Instruction *In = cast<Instruction>(PHI->getIncomingValue (I));
4064
-
4065
- IRB.SetInsertPoint (In->getParent (), std::next (In->getIterator ()));
4066
- Type *Ty = GEPI.getSourceElementType ();
4067
- NewVal = IRB.CreateGEP (Ty, In, Index, In->getName () + " .sroa.gep" ,
4068
- IsInBounds);
4102
+ SmallVector<Value *> NewOps = GetNewOps (Op);
4103
+ NewGEP =
4104
+ IRB.CreateGEP (SourceTy, NewOps[0 ], ArrayRef (NewOps).drop_front (),
4105
+ Phi->getName () + " .sroa.gep" , IsInBounds);
4069
4106
}
4070
- NewPN ->addIncoming (NewVal, B );
4107
+ NewPhi ->addIncoming (NewGEP, BB );
4071
4108
}
4072
4109
4073
4110
Visited.erase (&GEPI);
4074
- GEPI.replaceAllUsesWith (NewPN );
4111
+ GEPI.replaceAllUsesWith (NewPhi );
4075
4112
GEPI.eraseFromParent ();
4076
- Visited.insert (NewPN );
4077
- enqueueUsers (*NewPN );
4113
+ Visited.insert (NewPhi );
4114
+ enqueueUsers (*NewPhi );
4078
4115
4079
4116
LLVM_DEBUG (dbgs () << " to: " ;
4080
4117
for (Value *In
4081
- : NewPN ->incoming_values ()) dbgs ()
4118
+ : NewPhi ->incoming_values ()) dbgs ()
4082
4119
<< " \n " << *In;
4083
- dbgs () << " \n " << *NewPN << ' \n ' );
4120
+ dbgs () << " \n " << *NewPhi << ' \n ' );
4084
4121
4085
4122
return true ;
4086
4123
}
4087
4124
4088
4125
bool visitGetElementPtrInst (GetElementPtrInst &GEPI) {
4089
- if (foldGEPSelect (GEPI))
4126
+ if (unfoldGEPSelect (GEPI))
4090
4127
return true ;
4091
4128
4092
- if (isa<PHINode>(GEPI. getPointerOperand ()) && foldGEPPhi (GEPI))
4129
+ if (unfoldGEPPhi (GEPI))
4093
4130
return true ;
4094
4131
4095
4132
enqueueUsers (GEPI);
0 commit comments