Skip to content

Commit 2568e52

Browse files
authored
[X86,SimplifyCFG] Support hoisting load/store with conditional faulting (Part II) (#108812)
This is a follow up of #96878 to support hoisting load/store from BBs have the same predecessor, if load/store are the only instructions and the branch is unpredictable, e.g.: ``` void test (int a, int *c, int *d) { if (a) *c = a; else *d = a; } ```
1 parent b9731a4 commit 2568e52

File tree

2 files changed

+123
-30
lines changed

2 files changed

+123
-30
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,21 +1662,43 @@ static bool areIdenticalUpToCommutativity(const Instruction *I1,
16621662
/// \endcode
16631663
///
16641664
/// So we need to turn hoisted load/store into cload/cstore.
1665+
///
1666+
/// \param BI The branch instruction.
1667+
/// \param SpeculatedConditionalLoadsStores The load/store instructions that
1668+
/// will be speculated.
1669+
/// \param Invert indicates if speculates FalseBB. Only used in triangle CFG.
16651670
static void hoistConditionalLoadsStores(
16661671
BranchInst *BI,
16671672
SmallVectorImpl<Instruction *> &SpeculatedConditionalLoadsStores,
1668-
bool Invert) {
1673+
std::optional<bool> Invert) {
16691674
auto &Context = BI->getParent()->getContext();
16701675
auto *VCondTy = FixedVectorType::get(Type::getInt1Ty(Context), 1);
16711676
auto *Cond = BI->getOperand(0);
16721677
// Construct the condition if needed.
16731678
BasicBlock *BB = BI->getParent();
1674-
IRBuilder<> Builder(SpeculatedConditionalLoadsStores.back());
1675-
Value *Mask = Builder.CreateBitCast(
1676-
Invert ? Builder.CreateXor(Cond, ConstantInt::getTrue(Context)) : Cond,
1677-
VCondTy);
1679+
IRBuilder<> Builder(
1680+
Invert.has_value() ? SpeculatedConditionalLoadsStores.back() : BI);
1681+
Value *Mask = nullptr;
1682+
Value *MaskFalse = nullptr;
1683+
Value *MaskTrue = nullptr;
1684+
if (Invert.has_value()) {
1685+
Mask = Builder.CreateBitCast(
1686+
*Invert ? Builder.CreateXor(Cond, ConstantInt::getTrue(Context)) : Cond,
1687+
VCondTy);
1688+
} else {
1689+
MaskFalse = Builder.CreateBitCast(
1690+
Builder.CreateXor(Cond, ConstantInt::getTrue(Context)), VCondTy);
1691+
MaskTrue = Builder.CreateBitCast(Cond, VCondTy);
1692+
}
1693+
auto PeekThroughBitcasts = [](Value *V) {
1694+
while (auto *BitCast = dyn_cast<BitCastInst>(V))
1695+
V = BitCast->getOperand(0);
1696+
return V;
1697+
};
16781698
for (auto *I : SpeculatedConditionalLoadsStores) {
1679-
IRBuilder<> Builder(I);
1699+
IRBuilder<> Builder(Invert.has_value() ? I : BI);
1700+
if (!Invert.has_value())
1701+
Mask = I->getParent() == BI->getSuccessor(0) ? MaskTrue : MaskFalse;
16801702
// We currently assume conditional faulting load/store is supported for
16811703
// scalar types only when creating new instructions. This can be easily
16821704
// extended for vector types in the future.
@@ -1688,12 +1710,14 @@ static void hoistConditionalLoadsStores(
16881710
auto *Ty = I->getType();
16891711
PHINode *PN = nullptr;
16901712
Value *PassThru = nullptr;
1691-
for (User *U : I->users())
1692-
if ((PN = dyn_cast<PHINode>(U))) {
1693-
PassThru = Builder.CreateBitCast(PN->getIncomingValueForBlock(BB),
1694-
FixedVectorType::get(Ty, 1));
1695-
break;
1696-
}
1713+
if (Invert.has_value())
1714+
for (User *U : I->users())
1715+
if ((PN = dyn_cast<PHINode>(U))) {
1716+
PassThru = Builder.CreateBitCast(
1717+
PeekThroughBitcasts(PN->getIncomingValueForBlock(BB)),
1718+
FixedVectorType::get(Ty, 1));
1719+
break;
1720+
}
16971721
MaskedLoadStore = Builder.CreateMaskedLoad(
16981722
FixedVectorType::get(Ty, 1), Op0, LI->getAlign(), Mask, PassThru);
16991723
Value *NewLoadStore = Builder.CreateBitCast(MaskedLoadStore, Ty);
@@ -1702,8 +1726,8 @@ static void hoistConditionalLoadsStores(
17021726
I->replaceAllUsesWith(NewLoadStore);
17031727
} else {
17041728
// Handle Store.
1705-
auto *StoredVal =
1706-
Builder.CreateBitCast(Op0, FixedVectorType::get(Op0->getType(), 1));
1729+
auto *StoredVal = Builder.CreateBitCast(
1730+
PeekThroughBitcasts(Op0), FixedVectorType::get(Op0->getType(), 1));
17071731
MaskedLoadStore = Builder.CreateMaskedStore(
17081732
StoredVal, I->getOperand(1), cast<StoreInst>(I)->getAlign(), Mask);
17091733
}
@@ -3155,7 +3179,8 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB,
31553179
return HaveRewritablePHIs;
31563180
}
31573181

3158-
static bool isProfitableToSpeculate(const BranchInst *BI, bool Invert,
3182+
static bool isProfitableToSpeculate(const BranchInst *BI,
3183+
std::optional<bool> Invert,
31593184
const TargetTransformInfo &TTI) {
31603185
// If the branch is non-unpredictable, and is predicted to *not* branch to
31613186
// the `then` block, then avoid speculating it.
@@ -3166,7 +3191,10 @@ static bool isProfitableToSpeculate(const BranchInst *BI, bool Invert,
31663191
if (!extractBranchWeights(*BI, TWeight, FWeight) || (TWeight + FWeight) == 0)
31673192
return true;
31683193

3169-
uint64_t EndWeight = Invert ? TWeight : FWeight;
3194+
if (!Invert.has_value())
3195+
return false;
3196+
3197+
uint64_t EndWeight = *Invert ? TWeight : FWeight;
31703198
BranchProbability BIEndProb =
31713199
BranchProbability::getBranchProbability(EndWeight, TWeight + FWeight);
31723200
BranchProbability Likely = TTI.getPredictableBranchThreshold();
@@ -8034,6 +8062,35 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
80348062
if (HoistCommon &&
80358063
hoistCommonCodeFromSuccessors(BI, !Options.HoistCommonInsts))
80368064
return requestResimplify();
8065+
8066+
if (BI && HoistLoadsStoresWithCondFaulting &&
8067+
Options.HoistLoadsStoresWithCondFaulting &&
8068+
isProfitableToSpeculate(BI, std::nullopt, TTI)) {
8069+
SmallVector<Instruction *, 2> SpeculatedConditionalLoadsStores;
8070+
auto CanSpeculateConditionalLoadsStores = [&]() {
8071+
for (auto *Succ : successors(BB)) {
8072+
for (Instruction &I : *Succ) {
8073+
if (I.isTerminator()) {
8074+
if (I.getNumSuccessors() > 1)
8075+
return false;
8076+
continue;
8077+
} else if (!isSafeCheapLoadStore(&I, TTI) ||
8078+
SpeculatedConditionalLoadsStores.size() ==
8079+
HoistLoadsStoresWithCondFaultingThreshold) {
8080+
return false;
8081+
}
8082+
SpeculatedConditionalLoadsStores.push_back(&I);
8083+
}
8084+
}
8085+
return !SpeculatedConditionalLoadsStores.empty();
8086+
};
8087+
8088+
if (CanSpeculateConditionalLoadsStores()) {
8089+
hoistConditionalLoadsStores(BI, SpeculatedConditionalLoadsStores,
8090+
std::nullopt);
8091+
return requestResimplify();
8092+
}
8093+
}
80378094
} else {
80388095
// If Successor #1 has multiple preds, we may be able to conditionally
80398096
// execute Successor #0 if it branches to Successor #1.

llvm/test/Transforms/SimplifyCFG/X86/hoist-loads-stores-with-cf.ll

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -276,34 +276,32 @@ if.false: ; preds = %if.true, %entry
276276
}
277277

278278
;; Both of successor 0 and successor 1 have a single predecessor.
279-
;; TODO: Support transform for this case.
280-
define void @single_predecessor(ptr %p, ptr %q, i32 %a) {
279+
define i32 @single_predecessor(ptr %p, ptr %q, i32 %a) {
281280
; CHECK-LABEL: @single_predecessor(
282281
; CHECK-NEXT: entry:
283282
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp ne i32 [[A:%.*]], 0
284-
; CHECK-NEXT: br i1 [[TOBOOL]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
285-
; CHECK: common.ret:
286-
; CHECK-NEXT: ret void
287-
; CHECK: if.end:
288-
; CHECK-NEXT: store i32 1, ptr [[Q:%.*]], align 4
289-
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
290-
; CHECK: if.then:
291-
; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[Q]], align 4
292-
; CHECK-NEXT: store i32 [[TMP0]], ptr [[P:%.*]], align 4
293-
; CHECK-NEXT: br label [[COMMON_RET]]
283+
; CHECK-NEXT: [[TMP0:%.*]] = xor i1 [[TOBOOL]], true
284+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast i1 [[TMP0]] to <1 x i1>
285+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast i1 [[TOBOOL]] to <1 x i1>
286+
; CHECK-NEXT: call void @llvm.masked.store.v1i32.p0(<1 x i32> splat (i32 1), ptr [[Q:%.*]], i32 4, <1 x i1> [[TMP2]])
287+
; CHECK-NEXT: [[TMP3:%.*]] = call <1 x i32> @llvm.masked.load.v1i32.p0(ptr [[Q]], i32 4, <1 x i1> [[TMP1]], <1 x i32> poison)
288+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <1 x i32> [[TMP3]] to i32
289+
; CHECK-NEXT: call void @llvm.masked.store.v1i32.p0(<1 x i32> [[TMP3]], ptr [[P:%.*]], i32 4, <1 x i1> [[TMP1]])
290+
; CHECK-NEXT: [[DOT:%.*]] = select i1 [[TOBOOL]], i32 2, i32 3
291+
; CHECK-NEXT: ret i32 [[DOT]]
294292
;
295293
entry:
296294
%tobool = icmp ne i32 %a, 0
297295
br i1 %tobool, label %if.end, label %if.then
298296

299297
if.end:
300298
store i32 1, ptr %q
301-
ret void
299+
ret i32 2
302300

303301
if.then:
304302
%0 = load i32, ptr %q
305303
store i32 %0, ptr %p
306-
ret void
304+
ret i32 3
307305
}
308306

309307
;; Hoist 6 stores.
@@ -759,6 +757,44 @@ if.true:
759757
ret i32 %res
760758
}
761759

760+
;; Not transform if either BB has multiple successors.
761+
define i32 @not_multi_successors(i1 %c1, i32 %c2, ptr %p) {
762+
; CHECK-LABEL: @not_multi_successors(
763+
; CHECK-NEXT: entry:
764+
; CHECK-NEXT: br i1 [[C1:%.*]], label [[ENTRY_IF:%.*]], label [[COMMON_RET:%.*]]
765+
; CHECK: entry.if:
766+
; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[P:%.*]], align 4
767+
; CHECK-NEXT: switch i32 [[C2:%.*]], label [[COMMON_RET]] [
768+
; CHECK-NEXT: i32 0, label [[SW_BB:%.*]]
769+
; CHECK-NEXT: i32 1, label [[SW_BB]]
770+
; CHECK-NEXT: ]
771+
; CHECK: common.ret:
772+
; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VAL]], [[ENTRY_IF]] ], [ 0, [[SW_BB]] ]
773+
; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
774+
; CHECK: sw.bb:
775+
; CHECK-NEXT: br label [[COMMON_RET]]
776+
;
777+
entry:
778+
br i1 %c1, label %entry.if, label %entry.else
779+
780+
entry.if: ; preds = %entry
781+
%val = load i32, ptr %p, align 4
782+
switch i32 %c2, label %return [
783+
i32 0, label %sw.bb
784+
i32 1, label %sw.bb
785+
]
786+
787+
entry.else: ; preds = %entry
788+
ret i32 0
789+
790+
sw.bb: ; preds = %entry.if, %entry.if
791+
br label %return
792+
793+
return: ; preds = %sw.bb, %entry.if
794+
%ret = phi i32 [ %val, %entry.if ], [ 0, %sw.bb ]
795+
ret i32 %ret
796+
}
797+
762798
declare i32 @read_memory_only() readonly nounwind willreturn speculatable
763799

764800
!llvm.dbg.cu = !{!0}

0 commit comments

Comments
 (0)