Skip to content

Commit 1077be1

Browse files
committed
[X86,SimplifyCFG] Support hoisting load/store with conditional faulting (Part II)
1 parent c9e5c42 commit 1077be1

File tree

2 files changed

+72
-16
lines changed

2 files changed

+72
-16
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,18 +1663,29 @@ static bool areIdenticalUpToCommutativity(const Instruction *I1,
16631663
static void hoistConditionalLoadsStores(
16641664
BranchInst *BI,
16651665
SmallVectorImpl<Instruction *> &SpeculatedConditionalLoadsStores,
1666-
bool Invert) {
1666+
std::optional<bool> Invert) {
16671667
auto &Context = BI->getParent()->getContext();
16681668
auto *VCondTy = FixedVectorType::get(Type::getInt1Ty(Context), 1);
16691669
auto *Cond = BI->getOperand(0);
16701670
// Construct the condition if needed.
16711671
BasicBlock *BB = BI->getParent();
1672-
IRBuilder<> Builder(SpeculatedConditionalLoadsStores.back());
1673-
Value *Mask = Builder.CreateBitCast(
1674-
Invert ? Builder.CreateXor(Cond, ConstantInt::getTrue(Context)) : Cond,
1675-
VCondTy);
1672+
IRBuilder<> Builder(Invert ? SpeculatedConditionalLoadsStores.back() : BI);
1673+
Value *Mask = nullptr;
1674+
Value *Mask0 = nullptr;
1675+
Value *Mask1 = nullptr;
1676+
if (Invert) {
1677+
Mask = Builder.CreateBitCast(
1678+
*Invert ? Builder.CreateXor(Cond, ConstantInt::getTrue(Context)) : Cond,
1679+
VCondTy);
1680+
} else {
1681+
Mask0 = Builder.CreateBitCast(
1682+
Builder.CreateXor(Cond, ConstantInt::getTrue(Context)), VCondTy);
1683+
Mask1 = Builder.CreateBitCast(Cond, VCondTy);
1684+
}
16761685
for (auto *I : SpeculatedConditionalLoadsStores) {
1677-
IRBuilder<> Builder(I);
1686+
IRBuilder<> Builder(Invert ? I : BI);
1687+
if (!Invert)
1688+
Mask = I->getParent() == BI->getSuccessor(0) ? Mask1 : Mask0;
16781689
// We currently assume conditional faulting load/store is supported for
16791690
// scalar types only when creating new instructions. This can be easily
16801691
// extended for vector types in the future.
@@ -1771,6 +1782,25 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(Instruction *TI,
17711782
if (Succ->hasAddressTaken() || !Succ->getSinglePredecessor())
17721783
return false;
17731784

1785+
auto *BI = dyn_cast<BranchInst>(TI);
1786+
if (BI && HoistLoadsStoresWithCondFaulting &&
1787+
Options.HoistLoadsStoresWithCondFaulting) {
1788+
SmallVector<Instruction *, 2> SpeculatedConditionalLoadsStores;
1789+
for (auto *Succ : successors(BB)) {
1790+
for (Instruction &I : drop_end(*Succ)) {
1791+
if (!isSafeCheapLoadStore(&I, TTI) ||
1792+
SpeculatedConditionalLoadsStores.size() ==
1793+
HoistLoadsStoresWithCondFaultingThreshold)
1794+
return false;
1795+
SpeculatedConditionalLoadsStores.push_back(&I);
1796+
}
1797+
}
1798+
1799+
if (!SpeculatedConditionalLoadsStores.empty())
1800+
hoistConditionalLoadsStores(BI, SpeculatedConditionalLoadsStores,
1801+
std::nullopt);
1802+
}
1803+
17741804
// The second of pair is a SkipFlags bitmask.
17751805
using SuccIterPair = std::pair<BasicBlock::iterator, unsigned>;
17761806
SmallVector<SuccIterPair, 8> SuccIterPairs;

llvm/test/Transforms/SimplifyCFG/X86/hoist-loads-stores-with-cf.ll

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,21 +276,19 @@ if.false: ; preds = %if.true, %entry
276276
}
277277

278278
;; Both successor 0 and successor 1 have a single predecessor.
279-
;; TODO: Support transform for this case.
280279
define void @single_predecessor(ptr %p, ptr %q, i32 %a) {
281280
; CHECK-LABEL: @single_predecessor(
282281
; CHECK-NEXT: entry:
283282
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp ne i32 [[A:%.*]], 0
284-
; CHECK-NEXT: br i1 [[TOBOOL]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
285-
; CHECK: common.ret:
283+
; CHECK-NEXT: [[TMP0:%.*]] = xor i1 [[TOBOOL]], true
284+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast i1 [[TMP0]] to <1 x i1>
285+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast i1 [[TOBOOL]] to <1 x i1>
286+
; CHECK-NEXT: call void @llvm.masked.store.v1i32.p0(<1 x i32> <i32 1>, ptr [[Q:%.*]], i32 4, <1 x i1> [[TMP2]])
287+
; CHECK-NEXT: [[TMP3:%.*]] = call <1 x i32> @llvm.masked.load.v1i32.p0(ptr [[Q]], i32 4, <1 x i1> [[TMP1]], <1 x i32> poison)
288+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <1 x i32> [[TMP3]] to i32
289+
; CHECK-NEXT: [[TMP5:%.*]] = bitcast i32 [[TMP4]] to <1 x i32>
290+
; CHECK-NEXT: call void @llvm.masked.store.v1i32.p0(<1 x i32> [[TMP5]], ptr [[P:%.*]], i32 4, <1 x i1> [[TMP1]])
286291
; CHECK-NEXT: ret void
287-
; CHECK: if.end:
288-
; CHECK-NEXT: store i32 1, ptr [[Q:%.*]], align 4
289-
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
290-
; CHECK: if.then:
291-
; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[Q]], align 4
292-
; CHECK-NEXT: store i32 [[TMP0]], ptr [[P:%.*]], align 4
293-
; CHECK-NEXT: br label [[COMMON_RET]]
294292
;
295293
entry:
296294
%tobool = icmp ne i32 %a, 0
@@ -728,6 +726,34 @@ if.true:
728726
ret i32 %res
729727
}
730728

729+
define void @diamondCFG(i32 %a, ptr %c, ptr %d) {
730+
; CHECK-LABEL: @diamondCFG(
731+
; CHECK-NEXT: entry:
732+
; CHECK-NEXT: [[TOBOOL_NOT:%.*]] = icmp eq i32 [[A:%.*]], 0
733+
; CHECK-NEXT: [[TMP0:%.*]] = xor i1 [[TOBOOL_NOT]], true
734+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast i1 [[TMP0]] to <1 x i1>
735+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast i1 [[TOBOOL_NOT]] to <1 x i1>
736+
; CHECK-NEXT: call void @llvm.masked.store.v1i32.p0(<1 x i32> zeroinitializer, ptr [[D:%.*]], i32 4, <1 x i1> [[TMP2]])
737+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast i32 [[A]] to <1 x i32>
738+
; CHECK-NEXT: call void @llvm.masked.store.v1i32.p0(<1 x i32> [[TMP3]], ptr [[C:%.*]], i32 4, <1 x i1> [[TMP1]])
739+
; CHECK-NEXT: ret void
740+
;
741+
entry:
742+
%tobool.not = icmp eq i32 %a, 0
743+
br i1 %tobool.not, label %if.else, label %if.then
744+
745+
if.then: ; preds = %entry
746+
store i32 %a, ptr %c, align 4
747+
br label %if.end
748+
749+
if.else: ; preds = %entry
750+
store i32 0, ptr %d, align 4
751+
br label %if.end
752+
753+
if.end: ; preds = %if.else, %if.then
754+
ret void
755+
}
756+
731757
declare i32 @read_memory_only() readonly nounwind willreturn speculatable
732758

733759
!llvm.dbg.cu = !{!0}

0 commit comments

Comments
 (0)