Skip to content

Commit 87c86aa

Browse files
authored
[X86,SimplifyCFG] Support hoisting load/store with conditional faulting (Part I) (#96878)
This is the SimplifyCFG part of #95515. In this PR, we support hoisting loads/stores with conditional faulting in `SimplifyCFGOpt::speculativelyExecuteBB` to eliminate conditional branches. This is for cases like ``` void test (int a, int *b) { if (a) *b = a; } ``` In follow-up patches, we will support the hoisting in `SimplifyCFGOpt::hoistCommonCodeFromSuccessors`. That is for cases like ``` void test (int a, int *c, int *d) { if (a) *c = a; else *d = a; } ```
1 parent 438ad9f commit 87c86aa

File tree

8 files changed

+913
-14
lines changed

8 files changed

+913
-14
lines changed

llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ struct SimplifyCFGOptions {
2727
bool ConvertSwitchToLookupTable = false;
2828
bool NeedCanonicalLoop = true;
2929
bool HoistCommonInsts = false;
30+
bool HoistLoadsStoresWithCondFaulting = false;
3031
bool SinkCommonInsts = false;
3132
bool SimplifyCondBranch = true;
3233
bool SpeculateBlocks = true;
@@ -59,6 +60,10 @@ struct SimplifyCFGOptions {
5960
HoistCommonInsts = B;
6061
return *this;
6162
}
63+
/// Sets HoistLoadsStoresWithCondFaulting to \p B and returns *this so that
/// option setters can be chained.
SimplifyCFGOptions &hoistLoadsStoresWithCondFaulting(bool B) {
64+
HoistLoadsStoresWithCondFaulting = B;
65+
return *this;
66+
}
6267
SimplifyCFGOptions &sinkCommonInsts(bool B) {
6368
SinkCommonInsts = B;
6469
return *this;

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,8 @@ Expected<SimplifyCFGOptions> parseSimplifyCFGOptions(StringRef Params) {
848848
Result.needCanonicalLoops(Enable);
849849
} else if (ParamName == "hoist-common-insts") {
850850
Result.hoistCommonInsts(Enable);
851+
} else if (ParamName == "hoist-loads-stores-with-cond-faulting") {
852+
Result.hoistLoadsStoresWithCondFaulting(Enable);
851853
} else if (ParamName == "sink-common-insts") {
852854
Result.sinkCommonInsts(Enable);
853855
} else if (ParamName == "speculate-unpredictables") {

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,9 +1534,11 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
15341534

15351535
// LoopSink (and other loop passes since the last simplifyCFG) might have
15361536
// resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
1537-
OptimizePM.addPass(SimplifyCFGPass(SimplifyCFGOptions()
1538-
.convertSwitchRangeToICmp(true)
1539-
.speculateUnpredictables(true)));
1537+
OptimizePM.addPass(
1538+
SimplifyCFGPass(SimplifyCFGOptions()
1539+
.convertSwitchRangeToICmp(true)
1540+
.speculateUnpredictables(true)
1541+
.hoistLoadsStoresWithCondFaulting(true)));
15401542

15411543
// Add the core optimizing pipeline.
15421544
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizePM),

llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ static cl::opt<bool> UserHoistCommonInsts(
7373
"hoist-common-insts", cl::Hidden, cl::init(false),
7474
cl::desc("hoist common instructions (default = false)"));
7575

76+
static cl::opt<bool> UserHoistLoadsStoresWithCondFaulting(
77+
"hoist-loads-stores-with-cond-faulting", cl::Hidden, cl::init(false),
78+
cl::desc("Hoist loads/stores if the target supports conditional faulting "
79+
"(default = false)"));
80+
7681
static cl::opt<bool> UserSinkCommonInsts(
7782
"sink-common-insts", cl::Hidden, cl::init(false),
7883
cl::desc("Sink common instructions (default = false)"));
@@ -326,6 +331,9 @@ static void applyCommandLineOverridesToOptions(SimplifyCFGOptions &Options) {
326331
Options.NeedCanonicalLoop = UserKeepLoops;
327332
if (UserHoistCommonInsts.getNumOccurrences())
328333
Options.HoistCommonInsts = UserHoistCommonInsts;
334+
if (UserHoistLoadsStoresWithCondFaulting.getNumOccurrences())
335+
Options.HoistLoadsStoresWithCondFaulting =
336+
UserHoistLoadsStoresWithCondFaulting;
329337
if (UserSinkCommonInsts.getNumOccurrences())
330338
Options.SinkCommonInsts = UserSinkCommonInsts;
331339
if (UserSpeculateUnpredictables.getNumOccurrences())
@@ -354,6 +362,8 @@ void SimplifyCFGPass::printPipeline(
354362
<< "switch-to-lookup;";
355363
OS << (Options.NeedCanonicalLoop ? "" : "no-") << "keep-loops;";
356364
OS << (Options.HoistCommonInsts ? "" : "no-") << "hoist-common-insts;";
365+
OS << (Options.HoistLoadsStoresWithCondFaulting ? "" : "no-")
366+
<< "hoist-loads-stores-with-cond-faulting;";
357367
OS << (Options.SinkCommonInsts ? "" : "no-") << "sink-common-insts;";
358368
OS << (Options.SpeculateBlocks ? "" : "no-") << "speculate-blocks;";
359369
OS << (Options.SimplifyCondBranch ? "" : "no-") << "simplify-cond-branch;";

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 155 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,18 @@ static cl::opt<bool>
117117
HoistCommon("simplifycfg-hoist-common", cl::Hidden, cl::init(true),
118118
cl::desc("Hoist common instructions up to the parent block"));
119119

120+
static cl::opt<bool> HoistLoadsStoresWithCondFaulting(
121+
"simplifycfg-hoist-loads-stores-with-cond-faulting", cl::Hidden,
122+
cl::init(true),
123+
cl::desc("Hoist loads/stores if the target supports "
124+
"conditional faulting"));
125+
126+
// Upper bound on how many conditional-faulting loads/stores may be collected
// for hoisting while speculating one block (default 6).
static cl::opt<unsigned> HoistLoadsStoresWithCondFaultingThreshold(
127+
"hoist-loads-stores-with-cond-faulting-threshold", cl::Hidden, cl::init(6),
128+
cl::desc("Control the maximal conditional load/store that we are willing "
129+
"to speculatively execute to eliminate conditional branch "
130+
"(default = 6)"));
131+
120132
static cl::opt<unsigned>
121133
HoistCommonSkipLimit("simplifycfg-hoist-common-skip-limit", cl::Hidden,
122134
cl::init(20),
@@ -2986,6 +2998,25 @@ static bool isProfitableToSpeculate(const BranchInst *BI, bool Invert,
29862998
return BIEndProb < Likely;
29872999
}
29883000

3001+
/// Returns true if \p I is a simple load or store whose accessed type is
/// reported by \p TTI as supporting conditional load/store, and whose
/// alignment is below Value::MaximumAlignment. Any other instruction yields
/// false.
static bool isSafeCheapLoadStore(const Instruction *I,
3002+
const TargetTransformInfo &TTI) {
3003+
// Volatile and atomic accesses are not handled.
3004+
if (auto *L = dyn_cast<LoadInst>(I)) {
3005+
if (!L->isSimple())
3006+
return false;
3007+
} else if (auto *S = dyn_cast<StoreInst>(I)) {
3008+
if (!S->isSimple())
3009+
return false;
3010+
} else
3011+
return false;
3012+
3013+
// llvm.masked.load/store use i32 for alignment, while load/store use i64.
3014+
// That is why the alignment is limited here.
3015+
// FIXME: Update the prototype of the intrinsics?
3016+
return TTI.hasConditionalLoadStoreForType(getLoadStoreType(I)) &&
3017+
getLoadStoreAlignment(I) < Value::MaximumAlignment;
3018+
}
3019+
29893020
/// Speculate a conditional basic block flattening the CFG.
29903021
///
29913022
/// Note that this is a very risky transform currently. Speculating
@@ -3060,6 +3091,9 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
30603091
SmallVector<Instruction *, 4> SpeculatedDbgIntrinsics;
30613092

30623093
unsigned SpeculatedInstructions = 0;
3094+
bool HoistLoadsStores = HoistLoadsStoresWithCondFaulting &&
3095+
Options.HoistLoadsStoresWithCondFaulting;
3096+
SmallVector<Instruction *, 2> SpeculatedConditionalLoadsStores;
30633097
Value *SpeculatedStoreValue = nullptr;
30643098
StoreInst *SpeculatedStore = nullptr;
30653099
EphemeralValueTracker EphTracker;
@@ -3088,22 +3122,33 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
30883122

30893123
// Only speculatively execute a single instruction (not counting the
30903124
// terminator) for now.
3091-
++SpeculatedInstructions;
3125+
bool IsSafeCheapLoadStore = HoistLoadsStores &&
3126+
isSafeCheapLoadStore(&I, TTI) &&
3127+
SpeculatedConditionalLoadsStores.size() <
3128+
HoistLoadsStoresWithCondFaultingThreshold;
3129+
// Do not count the load/store toward the cost if the target supports
3130+
// conditional faulting, because it is cheap to speculate.
3131+
if (IsSafeCheapLoadStore)
3132+
SpeculatedConditionalLoadsStores.push_back(&I);
3133+
else
3134+
++SpeculatedInstructions;
3135+
30923136
if (SpeculatedInstructions > 1)
30933137
return false;
30943138

30953139
// Don't hoist the instruction if it's unsafe or expensive.
3096-
if (!isSafeToSpeculativelyExecute(&I) &&
3097-
!(HoistCondStores && (SpeculatedStoreValue = isSafeToSpeculateStore(
3098-
&I, BB, ThenBB, EndBB))))
3140+
if (!IsSafeCheapLoadStore && !isSafeToSpeculativelyExecute(&I) &&
3141+
!(HoistCondStores && !SpeculatedStoreValue &&
3142+
(SpeculatedStoreValue =
3143+
isSafeToSpeculateStore(&I, BB, ThenBB, EndBB))))
30993144
return false;
3100-
if (!SpeculatedStoreValue &&
3145+
if (!IsSafeCheapLoadStore && !SpeculatedStoreValue &&
31013146
computeSpeculationCost(&I, TTI) >
31023147
PHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic)
31033148
return false;
31043149

31053150
// Store the store speculation candidate.
3106-
if (SpeculatedStoreValue)
3151+
if (!SpeculatedStore && SpeculatedStoreValue)
31073152
SpeculatedStore = cast<StoreInst>(&I);
31083153

31093154
// Do not hoist the instruction if any of its operands are defined but not
@@ -3130,11 +3175,11 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
31303175

31313176
// Check that we can insert the selects and that it's not too expensive to do
31323177
// so.
3133-
bool Convert = SpeculatedStore != nullptr;
3178+
bool Convert =
3179+
SpeculatedStore != nullptr || !SpeculatedConditionalLoadsStores.empty();
31343180
InstructionCost Cost = 0;
31353181
Convert |= validateAndCostRequiredSelects(BB, ThenBB, EndBB,
3136-
SpeculatedInstructions,
3137-
Cost, TTI);
3182+
SpeculatedInstructions, Cost, TTI);
31383183
if (!Convert || Cost > Budget)
31393184
return false;
31403185

@@ -3222,6 +3267,107 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
32223267
BB->splice(BI->getIterator(), ThenBB, ThenBB->begin(),
32233268
std::prev(ThenBB->end()));
32243269

3270+
// If the target supports conditional faulting,
3271+
// we look for the following pattern:
3272+
// \code
3273+
// BB:
3274+
// ...
3275+
// %cond = icmp ult %x, %y
3276+
// br i1 %cond, label %TrueBB, label %FalseBB
3277+
// FalseBB:
3278+
// store i32 1, ptr %q, align 4
3279+
// ...
3280+
// TrueBB:
3281+
// %maskedloadstore = load i32, ptr %b, align 4
3282+
// store i32 %maskedloadstore, ptr %p, align 4
3283+
// ...
3284+
// \endcode
3285+
//
3286+
// and transform it into:
3287+
//
3288+
// \code
3289+
// BB:
3290+
// ...
3291+
// %cond = icmp ult %x, %y
3292+
// %maskedloadstore = cload i32, ptr %b, %cond
3293+
// cstore i32 %maskedloadstore, ptr %p, %cond
3294+
// cstore i32 1, ptr %q, ~%cond
3295+
// br i1 %cond, label %TrueBB, label %FalseBB
3296+
// FalseBB:
3297+
// ...
3298+
// TrueBB:
3299+
// ...
3300+
// \endcode
3301+
//
3302+
// where cload/cstore are represented by llvm.masked.load/store intrinsics,
3303+
// e.g.
3304+
//
3305+
// \code
3306+
// %vcond = bitcast i1 %cond to <1 x i1>
3307+
// %v0 = call <1 x i32> @llvm.masked.load.v1i32.p0
3308+
// (ptr %b, i32 4, <1 x i1> %vcond, <1 x i32> poison)
3309+
// %maskedloadstore = bitcast <1 x i32> %v0 to i32
3310+
// call void @llvm.masked.store.v1i32.p0
3311+
// (<1 x i32> %v0, ptr %p, i32 4, <1 x i1> %vcond)
3312+
// %cond.not = xor i1 %cond, true
3313+
// %vcond.not = bitcast i1 %cond.not to <1 x i1>
3314+
// call void @llvm.masked.store.v1i32.p0
3315+
// (<1 x i32> <i32 1>, ptr %q, i32 4, <1 x i1> %vcond.not)
3316+
// \endcode
3317+
//
3318+
// So we need to turn hoisted load/store into cload/cstore.
3319+
auto &Context = BI->getParent()->getContext();
3320+
auto *VCondTy = FixedVectorType::get(Type::getInt1Ty(Context), 1);
3321+
auto *Cond = BI->getOperand(0);
3322+
Value *Mask = nullptr;
3323+
// Construct the condition if needed.
3324+
if (!SpeculatedConditionalLoadsStores.empty()) {
3325+
IRBuilder<> Builder(SpeculatedConditionalLoadsStores.back());
3326+
Mask = Builder.CreateBitCast(
3327+
Invert ? Builder.CreateXor(Cond, ConstantInt::getTrue(Context)) : Cond,
3328+
VCondTy);
3329+
}
3330+
for (auto *I : SpeculatedConditionalLoadsStores) {
3331+
IRBuilder<> Builder(I);
3332+
// We currently assume conditional faulting load/store is supported for
3333+
// scalar types only when creating new instructions. This can be easily
3334+
// extended for vector types in the future.
3335+
assert(!getLoadStoreType(I)->isVectorTy() && "not implemented");
3336+
auto *Op0 = I->getOperand(0);
3337+
Instruction *MaskedLoadStore = nullptr;
3338+
if (auto *LI = dyn_cast<LoadInst>(I)) {
3339+
// Handle Load.
3340+
auto *Ty = I->getType();
3341+
MaskedLoadStore = Builder.CreateMaskedLoad(FixedVectorType::get(Ty, 1),
3342+
Op0, LI->getAlign(), Mask);
3343+
I->replaceAllUsesWith(Builder.CreateBitCast(MaskedLoadStore, Ty));
3344+
} else {
3345+
// Handle Store.
3346+
auto *StoredVal =
3347+
Builder.CreateBitCast(Op0, FixedVectorType::get(Op0->getType(), 1));
3348+
MaskedLoadStore = Builder.CreateMaskedStore(
3349+
StoredVal, I->getOperand(1), cast<StoreInst>(I)->getAlign(), Mask);
3350+
}
3351+
// For non-debug metadata, only !annotation, !range, !nonnull and !align are
3352+
// kept when hoisting (see Instruction::dropUBImplyingAttrsAndMetadata).
3353+
//
3354+
// !nonnull, !align: Pointer types are not supported, so no need to keep.
3355+
// !range: Load type is changed from scalar to vector, but the metadata on
3356+
// vector specifies a per-element range, so the semantics stay the
3357+
// same. Keep it.
3358+
// !annotation: Does not affect semantics. Keep it.
3359+
I->dropUBImplyingAttrsAndUnknownMetadata(
3360+
{LLVMContext::MD_range, LLVMContext::MD_annotation});
3361+
// FIXME: DIAssignID is not supported for masked store yet.
3362+
// (Verifier::visitDIAssignIDMetadata)
3363+
at::deleteAssignmentMarkers(I);
3364+
I->eraseMetadataIf([](unsigned MDKind, MDNode *Node) {
3365+
return Node->getMetadataID() == Metadata::DIAssignIDKind;
3366+
});
3367+
MaskedLoadStore->copyMetadata(*I);
3368+
I->eraseFromParent();
3369+
}
3370+
32253371
// Insert selects and rewrite the PHI operands.
32263372
IRBuilder<NoFolder> Builder(BI);
32273373
for (PHINode &PN : EndBB->phis()) {

llvm/test/Other/new-pm-print-pipeline.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@
4949
; RUN: opt -disable-output -disable-verify -print-pipeline-passes -passes='function(print<stack-lifetime><may>,print<stack-lifetime><must>)' < %s | FileCheck %s --match-full-lines --check-prefixes=CHECK-17
5050
; CHECK-17: function(print<stack-lifetime><may>,print<stack-lifetime><must>)
5151

52-
; RUN: opt -disable-output -disable-verify -print-pipeline-passes -passes='function(simplifycfg<bonus-inst-threshold=5;forward-switch-cond;switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts;speculate-blocks;simplify-cond-branch;speculate-unpredictables>,simplifycfg<bonus-inst-threshold=7;no-forward-switch-cond;no-switch-to-lookup;no-keep-loops;no-hoist-common-insts;no-sink-common-insts;no-speculate-blocks;no-simplify-cond-branch;no-speculate-unpredictables>)' < %s | FileCheck %s --match-full-lines --check-prefixes=CHECK-18
53-
; CHECK-18: function(simplifycfg<bonus-inst-threshold=5;forward-switch-cond;no-switch-range-to-icmp;switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts;speculate-blocks;simplify-cond-branch;speculate-unpredictables>,simplifycfg<bonus-inst-threshold=7;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;no-keep-loops;no-hoist-common-insts;no-sink-common-insts;no-speculate-blocks;no-simplify-cond-branch;no-speculate-unpredictables>)
52+
; RUN: opt -disable-output -disable-verify -print-pipeline-passes -passes='function(simplifycfg<bonus-inst-threshold=5;forward-switch-cond;switch-to-lookup;keep-loops;hoist-common-insts;hoist-loads-stores-with-cond-faulting;sink-common-insts;speculate-blocks;simplify-cond-branch;speculate-unpredictables>,simplifycfg<bonus-inst-threshold=7;no-forward-switch-cond;no-switch-to-lookup;no-keep-loops;no-hoist-common-insts;no-hoist-loads-stores-with-cond-faulting;no-sink-common-insts;no-speculate-blocks;no-simplify-cond-branch;no-speculate-unpredictables>)' < %s | FileCheck %s --match-full-lines --check-prefixes=CHECK-18
53+
; CHECK-18: function(simplifycfg<bonus-inst-threshold=5;forward-switch-cond;no-switch-range-to-icmp;switch-to-lookup;keep-loops;hoist-common-insts;hoist-loads-stores-with-cond-faulting;sink-common-insts;speculate-blocks;simplify-cond-branch;speculate-unpredictables>,simplifycfg<bonus-inst-threshold=7;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;no-keep-loops;no-hoist-common-insts;no-hoist-loads-stores-with-cond-faulting;no-sink-common-insts;no-speculate-blocks;no-simplify-cond-branch;no-speculate-unpredictables>)
5454

5555
; RUN: opt -disable-output -disable-verify -print-pipeline-passes -passes='function(loop-vectorize<no-interleave-forced-only;no-vectorize-forced-only>,loop-vectorize<interleave-forced-only;vectorize-forced-only>)' < %s | FileCheck %s --match-full-lines --check-prefixes=CHECK-19
5656
; CHECK-19: function(loop-vectorize<no-interleave-forced-only;no-vectorize-forced-only;>,loop-vectorize<interleave-forced-only;vectorize-forced-only;>)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -mtriple=x86_64 -mattr=+cf -O1 -S | FileCheck %s
3+
4+
;; Test that masked.load/store.v1* is generated in simplifycfg and does not fall back to branch+load/store in the following passes.
5+
define void @basic(i1 %cond, ptr %b, ptr %p, ptr %q) {
6+
; CHECK-LABEL: @basic(
7+
; CHECK-NEXT: entry:
8+
; CHECK-NEXT: [[TMP0:%.*]] = bitcast i1 [[COND:%.*]] to <1 x i1>
9+
; CHECK-NEXT: [[TMP1:%.*]] = call <1 x i16> @llvm.masked.load.v1i16.p0(ptr [[P:%.*]], i32 2, <1 x i1> [[TMP0]], <1 x i16> poison)
10+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <1 x i16> [[TMP1]] to i16
11+
; CHECK-NEXT: [[TMP3:%.*]] = call <1 x i32> @llvm.masked.load.v1i32.p0(ptr [[Q:%.*]], i32 4, <1 x i1> [[TMP0]], <1 x i32> poison)
12+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <1 x i32> [[TMP3]] to i32
13+
; CHECK-NEXT: [[TMP5:%.*]] = call <1 x i64> @llvm.masked.load.v1i64.p0(ptr [[B:%.*]], i32 8, <1 x i1> [[TMP0]], <1 x i64> poison)
14+
; CHECK-NEXT: [[TMP6:%.*]] = bitcast <1 x i64> [[TMP5]] to i64
15+
; CHECK-NEXT: [[TMP7:%.*]] = bitcast i16 [[TMP2]] to <1 x i16>
16+
; CHECK-NEXT: call void @llvm.masked.store.v1i16.p0(<1 x i16> [[TMP7]], ptr [[B]], i32 2, <1 x i1> [[TMP0]])
17+
; CHECK-NEXT: [[TMP8:%.*]] = bitcast i32 [[TMP4]] to <1 x i32>
18+
; CHECK-NEXT: call void @llvm.masked.store.v1i32.p0(<1 x i32> [[TMP8]], ptr [[P]], i32 4, <1 x i1> [[TMP0]])
19+
; CHECK-NEXT: [[TMP9:%.*]] = bitcast i64 [[TMP6]] to <1 x i64>
20+
; CHECK-NEXT: call void @llvm.masked.store.v1i64.p0(<1 x i64> [[TMP9]], ptr [[Q]], i32 8, <1 x i1> [[TMP0]])
21+
; CHECK-NEXT: ret void
22+
;
23+
entry:
24+
br i1 %cond, label %if.true, label %if.false
25+
26+
if.false:
27+
br label %if.end
28+
29+
if.true:
30+
%pv = load i16, ptr %p, align 2
31+
%qv = load i32, ptr %q, align 4
32+
%bv = load i64, ptr %b, align 8
33+
store i16 %pv, ptr %b, align 2
34+
store i32 %qv, ptr %p, align 4
35+
store i64 %bv, ptr %q, align 8
36+
br label %if.false
37+
38+
if.end:
39+
ret void
40+
}

0 commit comments

Comments
 (0)