Skip to content

Commit f9f700f

Browse files
committed
[SimplifyCFG] Convert conditional load/store to masked version
1 parent 5743b28 commit f9f700f

File tree

1 file changed

+211
-5
lines changed

1 file changed

+211
-5
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 211 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ static cl::opt<bool> HoistCondStores(
131131
"simplifycfg-hoist-cond-stores", cl::Hidden, cl::init(true),
132132
cl::desc("Hoist conditional stores if an unconditional store precedes"));
133133

134+
// Hidden switch, enabled by default, that gates
// hoistLoadStoreWithCondFaultingFromSuccessors().
static cl::opt<bool> HoistLoadsStoresWithCondFaulting(
    "simplifycfg-hoist-loads-stores-with-cond-faulting",
    cl::desc("Hoist loads/stores if the target supports conditional faulting"),
    cl::init(true), cl::Hidden);
139+
134140
static cl::opt<bool> MergeCondStores(
135141
"simplifycfg-merge-cond-stores", cl::Hidden, cl::init(true),
136142
cl::desc("Hoist conditional stores even if an unconditional store does not "
@@ -275,6 +281,7 @@ class SimplifyCFGOpt {
275281
bool hoistSuccIdenticalTerminatorToSwitchOrIf(
276282
Instruction *TI, Instruction *I1,
277283
SmallVectorImpl<Instruction *> &OtherSuccTIs);
284+
bool hoistLoadStoreWithCondFaultingFromSuccessors(BasicBlock *BB);
278285
bool speculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB);
279286
bool simplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond,
280287
BasicBlock *TrueBB, BasicBlock *FalseBB,
@@ -2960,6 +2967,194 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB,
29602967
return HaveRewritablePHIs;
29612968
}
29622969

2970+
/// Hoist load/store instructions from the conditional successor blocks up into
2971+
/// the block.
2972+
///
2973+
/// We are looking for code like the following:
2974+
/// \code
2975+
/// BB:
2976+
/// ...
2977+
/// %cond = icmp ult %x, %y
2978+
/// br i1 %cond, label %TrueBB, label %FalseBB
2979+
/// FalseBB:
2980+
/// store i32 1, ptr %q, align 4
2981+
/// ...
2982+
/// TrueBB:
2983+
/// %0 = load i32, ptr %b, align 4
2984+
/// store i32 %0, ptr %p, align 4
2985+
/// ...
2986+
/// \endcode
2987+
///
2988+
/// We are going to transform this into:
2989+
///
2990+
/// \code
2991+
/// BB:
2992+
/// ...
2993+
/// %cond = icmp ult %x, %y
2994+
/// %0 = cload i32, ptr %b, %cond
2995+
/// cstore i32 %0, ptr %p, %cond
2996+
/// cstore i32 1, ptr %q, ~%cond
2997+
/// br i1 %cond, label %TrueBB, label %FalseBB
2998+
/// FalseBB:
2999+
/// ...
3000+
/// TrueBB:
3001+
/// ...
3002+
/// \endcode
3003+
///
3004+
/// where cload/cstore is represented by intrinsic like llvm.masked.load/store,
3005+
/// e.g.
3006+
///
3007+
/// \code
3008+
/// %vcond = bitcast i1 %cond to <1 x i1>
3009+
/// %v0 = call <1 x i32> @llvm.masked.load.v1i32.p0
3010+
/// (ptr %b, i32 4, <1 x i1> %vcond, <1 x i32> poison)
3011+
/// %0 = bitcast <1 x i32> %v0 to i32
3012+
/// call void @llvm.masked.store.v1i32.p0
3013+
/// (<1 x i32> %v0, ptr %p, i32 4, <1 x i1> %vcond)
3014+
/// %cond.not = xor i1 %cond, true
3015+
/// %vcond.not = bitcast i1 %cond.not to <1 x i1>
3016+
/// call void @llvm.masked.store.v1i32.p0
3017+
/// (<1 x i32> <i32 1>, ptr %q, i32 4, <1 x i1> %vcond.not)
3018+
/// \endcode
3019+
///
3020+
/// \returns true if any load/store is hoisted.
3021+
///
3022+
/// Note that this transform should be run
3023+
/// * before SpeculativelyExecuteBB so that the latter can have more chance.
3024+
/// * after hoistCommonCodeFromSuccessors to ensure unconditional loads/stores
3025+
/// are handled first.
3026+
bool SimplifyCFGOpt::hoistLoadStoreWithCondFaultingFromSuccessors(
    BasicBlock *BB) {
  // Nothing to do when the transform is disabled on the command line or the
  // target reports no conditional-faulting load/store support.
  if (!HoistLoadsStoresWithCondFaulting ||
      !TTI.hasConditionalLoadStoreForType())
    return false;

  auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
  if (!BI || !BI->isConditional())
    return false;

  BasicBlock *IfTrueBB = BI->getSuccessor(0);
  BasicBlock *IfFalseBB = BI->getSuccessor(1);

  // If either of the blocks has its address taken, then we can't do this fold,
  // because the code we'd hoist would no longer run when we jump into the block
  // by its address.
  for (auto *Succ : {IfTrueBB, IfFalseBB})
    if (Succ->hasAddressTaken())
      return false;

  // Deliberately not isa<AllocaInst>(getUnderlyingObject(I.getOperand(0))):
  // that would require checking that every intermediate operand dominates the
  // branch.
  auto IsLoadFromAlloca = [](const Instruction &I) {
    return isa<LoadInst>(I) && isa<AllocaInst>(I.getOperand(0));
  };

  // Loads/stores selected for hoisting, in program order per successor.
  SmallSetVector<Instruction *, 4> HoistedInsts;

  // Collects the loads/stores of \p Succ into HoistedInsts. Returns false
  // (leaving HoistedInsts partially filled; the caller then gives up) if
  //  1. the target has no conditional faulting load/store for the type,
  //  2. a load/store is volatile or atomic,
  //  3. a non-load/store is not safe to speculatively execute,
  //  4. an operand of a load/store does not dominate the branch,
  //  5. a store follows a memory read that stays behind in the block.
  auto HoistInstsInBB = [&](BasicBlock *Succ) {
    bool SkipMemoryRead = false;
    // Cheap dominance check. Because BB is the only predecessor of Succ and
    // the branch defines no value, an operand dominates the branch iff it is
    // not an instruction of Succ, or it is itself going to be hoisted.
    auto OpsDominatesBranch = [&](Instruction &I) {
      return llvm::all_of(I.operands(), [&](Value *Op) {
        auto *J = dyn_cast<Instruction>(Op);
        return !J || HoistedInsts.contains(J) ||
               J->getParent() != I.getParent();
      });
    };

    for (auto &I : *Succ) {
      auto *LI = dyn_cast<LoadInst>(&I);
      auto *SI = dyn_cast<StoreInst>(&I);
      if (LI || SI) {
        bool IsSimple = LI ? LI->isSimple() : SI->isSimple();
        if (!IsSimple || !OpsDominatesBranch(I))
          return false;
        // Accessed type: the loaded type, or the type of the stored value.
        auto *AccessTy = LI ? I.getType() : I.getOperand(0)->getType();
        // A load from an alloca is hoisted as a plain load, so it does not
        // need target support for a conditional-faulting access.
        if (!IsLoadFromAlloca(I) &&
            !TTI.hasConditionalLoadStoreForType(AccessTy))
          return false;
        // Conservative aliasing check: never move a store above a memory
        // read that is left in the block.
        if (SI && SkipMemoryRead)
          return false;
        HoistedInsts.insert(&I);
      } else if (!I.isTerminator() && !isSafeToSpeculativelyExecute(&I)) {
        return false;
      } else if (I.mayReadFromMemory()) {
        SkipMemoryRead = true;
      }
    }
    return true;
  };

  // Only hoist out of successors that are reached exclusively through this
  // branch. A successor with additional predecessors (e.g. the merge block of
  // a triangle) must keep its loads/stores: predicating them on this branch's
  // condition and erasing them would drop the accesses on every other path
  // into that block.
  for (BasicBlock *Succ : {IfTrueBB, IfFalseBB}) {
    if (Succ->getSinglePredecessor() != BB)
      continue;
    if (!HoistInstsInBB(Succ))
      return false;
  }
  if (HoistedInsts.empty())
    return false;

  // Put newly added instructions before the BranchInst.
  IRBuilder<> Builder(BI);
  auto &Context = BB->getContext();
  auto *VCondTy = FixedVectorType::get(Type::getInt1Ty(Context), 1);
  auto *Cond = BI->getOperand(0);
  auto *VCond = Builder.CreateBitCast(Cond, VCondTy);
  Value *VCondNot = nullptr;
  for (auto *I : HoistedInsts) {
    // A load from an alloca only needs to change position. It now executes
    // unconditionally, so metadata that would make the speculated load UB
    // must go.
    if (IsLoadFromAlloca(*I)) {
      I->dropUBImplyingAttrsAndMetadata();
      I->moveBefore(BI);
      continue;
    }

    bool InvertCond = I->getParent() == IfFalseBB;
    // Construct the inverted condition lazily, on first use.
    if (InvertCond && !VCondNot)
      VCondNot = Builder.CreateBitCast(
          Builder.CreateXor(Cond, ConstantInt::getTrue(Context)), VCondTy);

    auto *Mask = InvertCond ? VCondNot : VCond;
    auto *Op0 = I->getOperand(0);
    Instruction *MaskedLoadStore = nullptr;
    if (auto *LI = dyn_cast<LoadInst>(I)) {
      // Load: <1 x Ty> masked load, then cast the single lane back to Ty.
      auto *Ty = I->getType();
      auto *V0 = Builder.CreateMaskedLoad(FixedVectorType::get(Ty, 1), Op0,
                                          LI->getAlign(), Mask);
      I->replaceAllUsesWith(Builder.CreateBitCast(V0, Ty));
      MaskedLoadStore = V0;
    } else {
      // Store: cast the scalar to <1 x Ty> and emit a masked store.
      auto *StoredVal =
          Builder.CreateBitCast(Op0, FixedVectorType::get(Op0->getType(), 1));
      MaskedLoadStore = Builder.CreateMaskedStore(
          StoredVal, I->getOperand(1), cast<StoreInst>(I)->getAlign(), Mask);
    }
    // The masked operation is executed on both paths and a masked-off load
    // produces poison, so only carry over metadata that stays valid there:
    // !range (per-element on the vector result) and !annotation.
    I->dropUBImplyingAttrsAndUnknownMetadata(
        {LLVMContext::MD_range, LLVMContext::MD_annotation});
    MaskedLoadStore->copyMetadata(*I);
  }

  // Erase the replaced instructions in reverse order so that no instruction
  // is deleted while a later hoisted instruction still uses it.
  for (Instruction *I : llvm::reverse(HoistedInsts))
    if (!IsLoadFromAlloca(*I))
      I->eraseFromParent();

  return true;
}
3157+
29633158
static bool isProfitableToSpeculate(const BranchInst *BI, bool Invert,
29643159
const TargetTransformInfo &TTI) {
29653160
// If the branch is non-unpredictable, and is predicted to *not* branch to
@@ -7519,31 +7714,42 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
75197714
return requestResimplify();
75207715

75217716
// We have a conditional branch to two blocks that are only reachable
7522-
// from BI. We know that the condbr dominates the two blocks, so see if
7523-
// there is any identical code in the "then" and "else" blocks. If so, we
7524-
// can hoist it up to the branching block.
7717+
// from BI. We know that the condbr dominates the two blocks, so see
7718+
//
7719+
// * if there is any identical code in the "then" and "else" blocks.
7720+
// * if there is any different load/store in the "then" and "else" blocks.
7721+
//
7722+
// If so, we can hoist it up to the branching block.
75257723
if (BI->getSuccessor(0)->getSinglePredecessor()) {
75267724
if (BI->getSuccessor(1)->getSinglePredecessor()) {
75277725
if (HoistCommon && hoistCommonCodeFromSuccessors(
75287726
BI->getParent(), !Options.HoistCommonInsts))
75297727
return requestResimplify();
7728+
if (hoistLoadStoreWithCondFaultingFromSuccessors(BI->getParent()))
7729+
return requestResimplify();
75307730
} else {
75317731
// If Successor #1 has multiple preds, we may be able to conditionally
75327732
// execute Successor #0 if it branches to Successor #1.
75337733
Instruction *Succ0TI = BI->getSuccessor(0)->getTerminator();
75347734
if (Succ0TI->getNumSuccessors() == 1 &&
7535-
Succ0TI->getSuccessor(0) == BI->getSuccessor(1))
7735+
Succ0TI->getSuccessor(0) == BI->getSuccessor(1)) {
7736+
if (hoistLoadStoreWithCondFaultingFromSuccessors(BI->getParent()))
7737+
return requestResimplify();
75367738
if (speculativelyExecuteBB(BI, BI->getSuccessor(0)))
75377739
return requestResimplify();
7740+
}
75387741
}
75397742
} else if (BI->getSuccessor(1)->getSinglePredecessor()) {
75407743
// If Successor #0 has multiple preds, we may be able to conditionally
75417744
// execute Successor #1 if it branches to Successor #0.
75427745
Instruction *Succ1TI = BI->getSuccessor(1)->getTerminator();
75437746
if (Succ1TI->getNumSuccessors() == 1 &&
7544-
Succ1TI->getSuccessor(0) == BI->getSuccessor(0))
7747+
Succ1TI->getSuccessor(0) == BI->getSuccessor(0)) {
7748+
if (hoistLoadStoreWithCondFaultingFromSuccessors(BI->getParent()))
7749+
return requestResimplify();
75457750
if (speculativelyExecuteBB(BI, BI->getSuccessor(1)))
75467751
return requestResimplify();
7752+
}
75477753
}
75487754

75497755
// If this is a branch on something for which we know the constant value in

0 commit comments

Comments
 (0)