Skip to content

Commit 9d69f2b

Browse files
authored
Merge pull request #75382 from kavon/static-branch-prediction
Throws Prediction + HotColdSplitting
2 parents 3fc034b + 11e8bb8 commit 9d69f2b

File tree

16 files changed

+283
-12
lines changed

16 files changed

+283
-12
lines changed

include/swift/AST/IRGenOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,9 @@ class IRGenOptions {
475475

476476
unsigned UseFragileResilientProtocolWitnesses : 1;
477477

478+
// Whether to run the HotColdSplitting pass when optimizing.
479+
unsigned EnableHotColdSplit : 1;
480+
478481
/// The number of threads for multi-threaded code generation.
479482
unsigned NumThreads = 0;
480483

@@ -564,6 +567,7 @@ class IRGenOptions {
564567
DisableReadonlyStaticObjects(false), CollocatedMetadataFunctions(false),
565568
ColocateTypeDescriptors(true), UseRelativeProtocolWitnessTables(false),
566569
UseFragileResilientProtocolWitnesses(false),
570+
EnableHotColdSplit(false),
567571
CmdArgs(), SanitizeCoverage(llvm::SanitizerCoverageOptions()),
568572
TypeInfoFilter(TypeInfoDumpFilter::All),
569573
PlatformCCallingConvention(llvm::CallingConv::C), UseCASBackend(false),

include/swift/AST/SILOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ class SILOptions {
118118
/// Controls whether to run async demotion pass.
119119
bool EnableAsyncDemotion = false;
120120

121+
/// Controls whether to always assume that functions rarely throw an Error
122+
/// within the optimizer. This influences static branch prediction.
123+
bool EnableThrowsPrediction = false;
124+
121125
/// Should we run any SIL performance optimizations
122126
///
123127
/// Useful when you want to enable -O LLVM opts but not -O SIL opts.

include/swift/Option/FrontendOptions.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,11 @@ def enable_spec_devirt : Flag<["-"], "enable-spec-devirt">,
540540
def enable_async_demotion : Flag<["-"], "enable-experimental-async-demotion">,
541541
HelpText<"Enables an optimization pass to demote async functions.">;
542542

543+
def enable_throws_prediction : Flag<["-"], "enable-throws-prediction">,
544+
HelpText<"Enables optimization assumption that functions rarely throw errors.">;
545+
def disable_throws_prediction : Flag<["-"], "disable-throws-prediction">,
546+
HelpText<"Disables optimization assumption that functions rarely throw errors.">;
547+
543548
def disable_access_control : Flag<["-"], "disable-access-control">,
544549
HelpText<"Don't respect access control restrictions">;
545550
def enable_access_control : Flag<["-"], "enable-access-control">,
@@ -1305,6 +1310,13 @@ def disable_fragile_resilient_protocol_witnesses :
13051310
Flag<["-"], "disable-fragile-relative-protocol-tables">,
13061311
HelpText<"Disable relative protocol witness tables">;
13071312

1313+
def enable_split_cold_code :
1314+
Flag<["-"], "enable-split-cold-code">,
1315+
HelpText<"Enable splitting of cold code when optimizing">;
1316+
def disable_split_cold_code :
1317+
Flag<["-"], "disable-split-cold-code">,
1318+
HelpText<"Disable splitting of cold code when optimizing">;
1319+
13081320
def enable_new_llvm_pass_manager :
13091321
Flag<["-"], "enable-new-llvm-pass-manager">,
13101322
HelpText<"Enable the new llvm pass manager">;

include/swift/SIL/SILBuilder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,10 +555,13 @@ class SILBuilder {
555555
ArrayRef<SILValue> args, SILBasicBlock *normalBB, SILBasicBlock *errorBB,
556556
ApplyOptions options = ApplyOptions(),
557557
const GenericSpecializationInformation *specializationInfo = nullptr,
558-
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt) {
558+
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt,
559+
ProfileCounter normalCount = ProfileCounter(),
560+
ProfileCounter errorCount = ProfileCounter()) {
559561
return insertTerminator(TryApplyInst::create(
560562
getSILDebugLocation(loc), callee, subs, args, normalBB, errorBB,
561-
options, *F, specializationInfo, isolationCrossing));
563+
options, *F, specializationInfo, isolationCrossing,
564+
normalCount, errorCount));
562565
}
563566

564567
PartialApplyInst *createPartialApply(

include/swift/SIL/SILInstruction.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10993,7 +10993,8 @@ class TryApplyInstBase : public TermInst {
1099310993

1099410994
protected:
1099510995
TryApplyInstBase(SILInstructionKind valueKind, SILDebugLocation Loc,
10996-
SILBasicBlock *normalBB, SILBasicBlock *errorBB);
10996+
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
10997+
ProfileCounter normalCount, ProfileCounter errorCount);
1099710998

1099810999
public:
1099911000
SuccessorListTy getSuccessors() {
@@ -11013,6 +11014,11 @@ class TryApplyInstBase : public TermInst {
1101311014
const SILBasicBlock *getNormalBB() const { return DestBBs[NormalIdx]; }
1101411015
SILBasicBlock *getErrorBB() { return DestBBs[ErrorIdx]; }
1101511016
const SILBasicBlock *getErrorBB() const { return DestBBs[ErrorIdx]; }
11017+
11018+
/// The number of times the Normal branch was executed
11019+
ProfileCounter getNormalBBCount() const { return DestBBs[NormalIdx].getCount(); }
11020+
/// The number of times the Error branch was executed
11021+
ProfileCounter getErrorBBCount() const { return DestBBs[ErrorIdx].getCount(); }
1101611022
};
1101711023

1101811024
/// TryApplyInst - Represents the full application of a function that
@@ -11030,15 +11036,19 @@ class TryApplyInst final
1103011036
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
1103111037
ApplyOptions options,
1103211038
const GenericSpecializationInformation *specializationInfo,
11033-
std::optional<ApplyIsolationCrossing> isolationCrossing);
11039+
std::optional<ApplyIsolationCrossing> isolationCrossing,
11040+
ProfileCounter normalCount,
11041+
ProfileCounter errorCount);
1103411042

1103511043
static TryApplyInst *
1103611044
create(SILDebugLocation debugLoc, SILValue callee,
1103711045
SubstitutionMap substitutions, ArrayRef<SILValue> args,
1103811046
SILBasicBlock *normalBB, SILBasicBlock *errorBB, ApplyOptions options,
1103911047
SILFunction &parentFunction,
1104011048
const GenericSpecializationInformation *specializationInfo,
11041-
std::optional<ApplyIsolationCrossing> isolationCrossing);
11049+
std::optional<ApplyIsolationCrossing> isolationCrossing,
11050+
ProfileCounter normalCount,
11051+
ProfileCounter errorCount);
1104211052
};
1104311053

1104411054
/// DifferentiableFunctionInst - creates a `@differentiable` function-typed

lib/DriverTool/sil_opt_main.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,10 @@ struct SILOptOptions {
289289
EnableAsyncDemotion = llvm::cl::opt<bool>("enable-async-demotion",
290290
llvm::cl::desc("Enables an optimization pass to demote async functions."));
291291

292+
llvm::cl::opt<bool>
293+
EnableThrowsPrediction = llvm::cl::opt<bool>("enable-throws-prediction",
294+
llvm::cl::desc("Enables optimization assumption that functions rarely throw errors."));
295+
292296
llvm::cl::opt<bool>
293297
EnableMoveInoutStackProtection = llvm::cl::opt<bool>("enable-move-inout-stack-protector",
294298
llvm::cl::desc("Enable the stack protector by moving values to temporaries."));
@@ -847,6 +851,7 @@ int sil_opt_main(ArrayRef<const char *> argv, void *MainAddr) {
847851

848852
SILOpts.EnableSpeculativeDevirtualization = options.EnableSpeculativeDevirtualization;
849853
SILOpts.EnableAsyncDemotion = options.EnableAsyncDemotion;
854+
SILOpts.EnableThrowsPrediction = options.EnableThrowsPrediction;
850855
SILOpts.IgnoreAlwaysInline = options.IgnoreAlwaysInline;
851856
SILOpts.EnableOSSAModules = options.EnableOSSAModules;
852857
SILOpts.EnableSILOpaqueValues = options.EnableSILOpaqueValues;

lib/Frontend/CompilerInvocation.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2661,6 +2661,9 @@ static bool ParseSILArgs(SILOptions &Opts, ArgList &Args,
26612661
OPT_enable_sil_opaque_values, OPT_disable_sil_opaque_values, false);
26622662
Opts.EnableSpeculativeDevirtualization |= Args.hasArg(OPT_enable_spec_devirt);
26632663
Opts.EnableAsyncDemotion |= Args.hasArg(OPT_enable_async_demotion);
2664+
Opts.EnableThrowsPrediction = Args.hasFlag(
2665+
OPT_enable_throws_prediction, OPT_disable_throws_prediction,
2666+
Opts.EnableThrowsPrediction);
26642667
Opts.EnableActorDataRaceChecks |= Args.hasFlag(
26652668
OPT_enable_actor_data_race_checks,
26662669
OPT_disable_actor_data_race_checks, /*default=*/false);
@@ -3423,6 +3426,10 @@ static bool ParseIRGenArgs(IRGenOptions &Opts, ArgList &Args,
34233426
Args.hasFlag(OPT_enable_fragile_resilient_protocol_witnesses,
34243427
OPT_disable_fragile_resilient_protocol_witnesses,
34253428
Opts.UseFragileResilientProtocolWitnesses);
3429+
Opts.EnableHotColdSplit =
3430+
Args.hasFlag(OPT_enable_split_cold_code,
3431+
OPT_disable_split_cold_code,
3432+
Opts.EnableHotColdSplit);
34263433
Opts.EnableLargeLoadableTypesReg2Mem =
34273434
Args.hasFlag(OPT_enable_large_loadable_types_reg2mem,
34283435
OPT_disable_large_loadable_types_reg2mem,

lib/IRGen/GenThunk.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,12 @@ void IRGenThunk::emit() {
357357
llvm::Value *nil = llvm::ConstantPointerNull::get(
358358
cast<llvm::PointerType>(errorValue->getType()));
359359
auto *hasError = IGF.Builder.CreateICmpNE(errorValue, nil);
360+
361+
// Predict no error is thrown.
362+
hasError =
363+
IGF.IGM.getSILModule().getOptions().EnableThrowsPrediction ?
364+
IGF.Builder.CreateExpectCond(IGF.IGM, hasError, false) : hasError;
365+
360366
IGF.Builder.CreateCondBr(hasError, errorBB, successBB);
361367

362368
IGF.Builder.emitBlock(errorBB);

lib/IRGen/IRBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,11 @@ class IRBuilder : public IRBuilderBase {
410410
name);
411411
}
412412

413+
// Creates an @llvm.expect.i1 call, where the value should be an i1 type.
414+
llvm::CallInst *CreateExpectCond(IRGenModule &IGM,
415+
llvm::Value *value,
416+
bool expectedValue, const Twine &name = "");
417+
413418
/// Call the trap intrinsic. If optimizations are enabled, an inline asm
414419
/// gadget is emitted before the trap. The gadget inhibits transforms which
415420
/// merge trap calls together, which makes debugging crashes easier.

lib/IRGen/IRGen.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ void swift::performLLVMOptimizations(const IRGenOptions &Opts,
211211
bool RunSwiftSpecificLLVMOptzns =
212212
!Opts.DisableSwiftSpecificLLVMOptzns && !Opts.DisableLLVMOptzns;
213213

214+
bool DoHotColdSplit = false;
214215
PTO.CallGraphProfile = false;
215216

216217
llvm::OptimizationLevel level = llvm::OptimizationLevel::O0;
@@ -221,6 +222,7 @@ void swift::performLLVMOptimizations(const IRGenOptions &Opts,
221222
PTO.LoopVectorization = true;
222223
PTO.SLPVectorization = true;
223224
PTO.MergeFunctions = true;
225+
DoHotColdSplit = Opts.EnableHotColdSplit;
224226
level = llvm::OptimizationLevel::Os;
225227
} else {
226228
level = llvm::OptimizationLevel::O0;
@@ -259,6 +261,8 @@ void swift::performLLVMOptimizations(const IRGenOptions &Opts,
259261
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
260262
ModulePassManager MPM;
261263

264+
PB.setEnableHotColdSplitting(DoHotColdSplit);
265+
262266
if (RunSwiftSpecificLLVMOptzns) {
263267
PB.registerScalarOptimizerLateEPCallback(
264268
[](FunctionPassManager &FPM, OptimizationLevel Level) {

lib/IRGen/IRGenFunction.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,14 @@ Address IRGenFunction::emitAddressAtOffset(llvm::Value *base, Offset offset,
479479
return Address(slotPtr, objectTy, objectAlignment);
480480
}
481481

482+
llvm::CallInst *IRBuilder::CreateExpectCond(IRGenModule &IGM,
483+
llvm::Value *value,
484+
bool expectedValue,
485+
const Twine &name) {
486+
unsigned flag = expectedValue ? 1 : 0;
487+
return CreateExpect(value, llvm::ConstantInt::get(IGM.Int1Ty, flag), name);
488+
}
489+
482490
llvm::CallInst *IRBuilder::CreateNonMergeableTrap(IRGenModule &IGM,
483491
StringRef failureMsg) {
484492
if (IGM.DebugInfo && IGM.getOptions().isDebugInfoCodeView()) {
@@ -810,6 +818,12 @@ void IRGenFunction::emitAwaitAsyncContinuation(
810818
auto nullError = llvm::Constant::getNullValue(errorRes->getType());
811819
auto hasError = Builder.CreateICmpNE(errorRes, nullError);
812820
optionalErrorResult->addIncoming(errorRes, Builder.GetInsertBlock());
821+
822+
// Predict no error.
823+
hasError =
824+
getSILModule().getOptions().EnableThrowsPrediction ?
825+
Builder.CreateExpectCond(IGM, hasError, false) : hasError;
826+
813827
Builder.CreateCondBr(hasError, optionalErrorBB, normalContBB);
814828
Builder.emitBlock(normalContBB);
815829
}

lib/IRGen/IRGenSIL.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3885,9 +3885,18 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
38853885
// FIXME: Remove this when the following radar is fixed: rdar://116636601
38863886
Builder.CreatePtrToInt(errorValue, IGM.IntPtrTy);
38873887

3888+
// Emit profile metadata if available.
3889+
llvm::MDNode *Weights = nullptr;
3890+
auto NormalBBCount = tryApplyInst->getNormalBBCount();
3891+
auto ErrorBBCount = tryApplyInst->getErrorBBCount();
3892+
if (NormalBBCount || ErrorBBCount)
3893+
Weights = IGM.createProfileWeights(ErrorBBCount ? ErrorBBCount.getValue() : 0,
3894+
NormalBBCount ? NormalBBCount.getValue() : 0);
3895+
38883896
Builder.CreateCondBr(hasError,
38893897
typedErrorLoadBB ? typedErrorLoadBB : errorDest.bb,
3890-
normalDest.bb);
3898+
normalDest.bb,
3899+
Weights);
38913900

38923901
// Set up the PHI nodes on the normal edge.
38933902
unsigned firstIndex = 0;
@@ -4634,6 +4643,12 @@ void IRGenSILFunction::visitYieldInst(swift::YieldInst *i) {
46344643
// Branch to the appropriate destination.
46354644
auto unwindBB = getLoweredBB(i->getUnwindBB()).bb;
46364645
auto resumeBB = getLoweredBB(i->getResumeBB()).bb;
4646+
4647+
// Predict no unwind happens.
4648+
isUnwind =
4649+
IGM.getSILModule().getOptions().EnableThrowsPrediction ?
4650+
Builder.CreateExpectCond(IGM, isUnwind, false) : isUnwind;
4651+
46374652
Builder.CreateCondBr(isUnwind, unwindBB, resumeBB);
46384653
}
46394654

lib/SIL/IR/SILInstructions.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -783,19 +783,24 @@ PartialApplyInst *PartialApplyInst::create(
783783
TryApplyInstBase::TryApplyInstBase(SILInstructionKind kind,
784784
SILDebugLocation loc,
785785
SILBasicBlock *normalBB,
786-
SILBasicBlock *errorBB)
787-
: TermInst(kind, loc), DestBBs{{{this, normalBB}, {this, errorBB}}} {}
786+
SILBasicBlock *errorBB,
787+
ProfileCounter normalCount,
788+
ProfileCounter errorCount)
789+
: TermInst(kind, loc), DestBBs{{{this, normalBB, normalCount},
790+
{this, errorBB, errorCount}}} {}
788791

789792
TryApplyInst::TryApplyInst(
790793
SILDebugLocation loc, SILValue callee, SILType substCalleeTy,
791794
SubstitutionMap subs, ArrayRef<SILValue> args,
792795
ArrayRef<SILValue> typeDependentOperands, SILBasicBlock *normalBB,
793796
SILBasicBlock *errorBB, ApplyOptions options,
794797
const GenericSpecializationInformation *specializationInfo,
795-
std::optional<ApplyIsolationCrossing> isolationCrossing)
798+
std::optional<ApplyIsolationCrossing> isolationCrossing,
799+
ProfileCounter normalCount,
800+
ProfileCounter errorCount)
796801
: InstructionBase(isolationCrossing, loc, callee, substCalleeTy, subs, args,
797802
typeDependentOperands, specializationInfo, normalBB,
798-
errorBB) {
803+
errorBB, normalCount, errorCount) {
799804
setApplyOptions(options);
800805
}
801806

@@ -805,19 +810,33 @@ TryApplyInst::create(SILDebugLocation loc, SILValue callee,
805810
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
806811
ApplyOptions options, SILFunction &parentFunction,
807812
const GenericSpecializationInformation *specializationInfo,
808-
std::optional<ApplyIsolationCrossing> isolationCrossing) {
813+
std::optional<ApplyIsolationCrossing> isolationCrossing,
814+
ProfileCounter normalCount,
815+
ProfileCounter errorCount) {
809816
SILType substCalleeTy = callee->getType().substGenericArgs(
810817
parentFunction.getModule(), subs,
811818
parentFunction.getTypeExpansionContext());
812819

820+
if (parentFunction.getModule().getOptions().EnableThrowsPrediction &&
821+
!normalCount && !errorCount) {
822+
// Predict that the error branch is not taken.
823+
//
824+
// We cannot use the Expect builtin within SIL because try_apply abstracts
825+
// over the raw conditional test to see if an error was returned.
826+
// So, we synthesize profiling branch weights instead.
827+
normalCount = 1999;
828+
errorCount = 0;
829+
}
830+
813831
SmallVector<SILValue, 32> typeDependentOperands;
814832
collectTypeDependentOperands(typeDependentOperands, parentFunction,
815833
substCalleeTy.getASTType(), subs);
816834
void *buffer = allocateTrailingInst<TryApplyInst, Operand>(
817835
parentFunction, getNumAllOperands(args, typeDependentOperands));
818836
return ::new (buffer) TryApplyInst(
819837
loc, callee, substCalleeTy, subs, args, typeDependentOperands, normalBB,
820-
errorBB, options, specializationInfo, isolationCrossing);
838+
errorBB, options, specializationInfo, isolationCrossing,
839+
normalCount, errorCount);
821840
}
822841

823842
SILType DifferentiableFunctionInst::getDifferentiableFunctionType(

lib/SIL/IR/SILPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
15731573
visitApplyInstBase(AI);
15741574
*this << ", normal " << Ctx.getID(AI->getNormalBB());
15751575
*this << ", error " << Ctx.getID(AI->getErrorBB());
1576+
if (AI->getNormalBBCount())
1577+
*this << " !normal_count(" << AI->getNormalBBCount().getValue() << ")";
1578+
if (AI->getErrorBBCount())
1579+
*this << " !error_count(" << AI->getErrorBBCount().getValue() << ")";
15761580
}
15771581

15781582
void visitPartialApplyInst(PartialApplyInst *CI) {

test/IRGen/cold_split.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %target-swift-frontend %s -module-name=test -emit-assembly \
2+
// RUN: -enable-throws-prediction -O -enable-split-cold-code \
3+
// RUN: | %FileCheck --check-prefix CHECK-ENABLED %s
4+
5+
//// Test disabling just the pass doesn't yield a split.
6+
// RUN: %target-swift-frontend %s -module-name=test -emit-assembly \
7+
// RUN: -enable-throws-prediction -O -disable-split-cold-code \
8+
// RUN: | %FileCheck --check-prefix CHECK-DISABLED %s
9+
10+
//// Test disabling optimization entirely doesn't yield a split.
11+
// RUN: %target-swift-frontend %s -module-name=test -emit-assembly \
12+
// RUN: -enable-throws-prediction -enable-split-cold-code \
13+
// RUN: | %FileCheck --check-prefix CHECK-DISABLED %s
14+
15+
16+
// CHECK-ENABLED: cold
17+
18+
// CHECK-DISABLED-NOT: cold
19+
20+
enum MyError: Error { case err }
21+
22+
func getRandom(_ b: Bool) throws -> Int {
23+
if b {
24+
return Int.random(in: 0..<1024)
25+
} else {
26+
throw MyError.err
27+
}
28+
}
29+
30+
public func numberWithLogging(_ b: Bool) -> Int {
31+
do {
32+
return try getRandom(b)
33+
} catch {
34+
print("Log: random number generator failed with b=\(b)")
35+
return 1337
36+
}
37+
}

0 commit comments

Comments
 (0)