Skip to content

Throws Prediction + HotColdSplitting #75382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/swift/AST/IRGenOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,9 @@ class IRGenOptions {

unsigned UseFragileResilientProtocolWitnesses : 1;

/// Whether to run the HotColdSplitting pass when optimizing.
unsigned EnableHotColdSplit : 1;

/// The number of threads for multi-threaded code generation.
unsigned NumThreads = 0;

Expand Down Expand Up @@ -564,6 +567,7 @@ class IRGenOptions {
DisableReadonlyStaticObjects(false), CollocatedMetadataFunctions(false),
ColocateTypeDescriptors(true), UseRelativeProtocolWitnessTables(false),
UseFragileResilientProtocolWitnesses(false),
EnableHotColdSplit(false),
CmdArgs(), SanitizeCoverage(llvm::SanitizerCoverageOptions()),
TypeInfoFilter(TypeInfoDumpFilter::All),
PlatformCCallingConvention(llvm::CallingConv::C), UseCASBackend(false),
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/SILOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ class SILOptions {
/// Controls whether to run async demotion pass.
bool EnableAsyncDemotion = false;

/// Controls whether to always assume that functions rarely throw an Error
/// within the optimizer. This influences static branch prediction.
bool EnableThrowsPrediction = false;

/// Should we run any SIL performance optimizations
///
/// Useful when you want to enable -O LLVM opts but not -O SIL opts.
Expand Down
12 changes: 12 additions & 0 deletions include/swift/Option/FrontendOptions.td
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,11 @@ def enable_spec_devirt : Flag<["-"], "enable-spec-devirt">,
def enable_async_demotion : Flag<["-"], "enable-experimental-async-demotion">,
HelpText<"Enables an optimization pass to demote async functions.">;

// Paired positive/negative frontend flags for throws prediction.
// ParseSILArgs reads the pair with Args.hasFlag and stores the result in
// SILOptions::EnableThrowsPrediction.
def enable_throws_prediction : Flag<["-"], "enable-throws-prediction">,
HelpText<"Enables optimization assumption that functions rarely throw errors.">;
def disable_throws_prediction : Flag<["-"], "disable-throws-prediction">,
HelpText<"Disables optimization assumption that functions rarely throw errors.">;

def disable_access_control : Flag<["-"], "disable-access-control">,
HelpText<"Don't respect access control restrictions">;
def enable_access_control : Flag<["-"], "enable-access-control">,
Expand Down Expand Up @@ -1305,6 +1310,13 @@ def disable_fragile_resilient_protocol_witnesses :
Flag<["-"], "disable-fragile-relative-protocol-tables">,
HelpText<"Disable relative protocol witness tables">;

// Paired positive/negative frontend flags for cold-code splitting.
// ParseIRGenArgs reads the pair with Args.hasFlag and stores the result in
// IRGenOptions::EnableHotColdSplit.
def enable_split_cold_code :
Flag<["-"], "enable-split-cold-code">,
HelpText<"Enable splitting of cold code when optimizing">;
def disable_split_cold_code :
Flag<["-"], "disable-split-cold-code">,
HelpText<"Disable splitting of cold code when optimizing">;

def enable_new_llvm_pass_manager :
Flag<["-"], "enable-new-llvm-pass-manager">,
HelpText<"Enable the new llvm pass manager">;
Expand Down
7 changes: 5 additions & 2 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,10 +555,13 @@ class SILBuilder {
ArrayRef<SILValue> args, SILBasicBlock *normalBB, SILBasicBlock *errorBB,
ApplyOptions options = ApplyOptions(),
const GenericSpecializationInformation *specializationInfo = nullptr,
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt) {
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt,
ProfileCounter normalCount = ProfileCounter(),
ProfileCounter errorCount = ProfileCounter()) {
return insertTerminator(TryApplyInst::create(
getSILDebugLocation(loc), callee, subs, args, normalBB, errorBB,
options, *F, specializationInfo, isolationCrossing));
options, *F, specializationInfo, isolationCrossing,
normalCount, errorCount));
}

PartialApplyInst *createPartialApply(
Expand Down
16 changes: 13 additions & 3 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -10968,7 +10968,8 @@ class TryApplyInstBase : public TermInst {

protected:
TryApplyInstBase(SILInstructionKind valueKind, SILDebugLocation Loc,
SILBasicBlock *normalBB, SILBasicBlock *errorBB);
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
ProfileCounter normalCount, ProfileCounter errorCount);

public:
SuccessorListTy getSuccessors() {
Expand All @@ -10988,6 +10989,11 @@ class TryApplyInstBase : public TermInst {
const SILBasicBlock *getNormalBB() const { return DestBBs[NormalIdx]; }
SILBasicBlock *getErrorBB() { return DestBBs[ErrorIdx]; }
const SILBasicBlock *getErrorBB() const { return DestBBs[ErrorIdx]; }

/// The number of times the Normal branch was executed
ProfileCounter getNormalBBCount() const { return DestBBs[NormalIdx].getCount(); }
/// The number of times the Error branch was executed
ProfileCounter getErrorBBCount() const { return DestBBs[ErrorIdx].getCount(); }
};

/// TryApplyInst - Represents the full application of a function that
Expand All @@ -11005,15 +11011,19 @@ class TryApplyInst final
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
ApplyOptions options,
const GenericSpecializationInformation *specializationInfo,
std::optional<ApplyIsolationCrossing> isolationCrossing);
std::optional<ApplyIsolationCrossing> isolationCrossing,
ProfileCounter normalCount,
ProfileCounter errorCount);

static TryApplyInst *
create(SILDebugLocation debugLoc, SILValue callee,
SubstitutionMap substitutions, ArrayRef<SILValue> args,
SILBasicBlock *normalBB, SILBasicBlock *errorBB, ApplyOptions options,
SILFunction &parentFunction,
const GenericSpecializationInformation *specializationInfo,
std::optional<ApplyIsolationCrossing> isolationCrossing);
std::optional<ApplyIsolationCrossing> isolationCrossing,
ProfileCounter normalCount,
ProfileCounter errorCount);
};

/// DifferentiableFunctionInst - creates a `@differentiable` function-typed
Expand Down
5 changes: 5 additions & 0 deletions lib/DriverTool/sil_opt_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ struct SILOptOptions {
EnableAsyncDemotion = llvm::cl::opt<bool>("enable-async-demotion",
llvm::cl::desc("Enables an optimization pass to demote async functions."));

llvm::cl::opt<bool>
EnableThrowsPrediction = llvm::cl::opt<bool>("enable-throws-prediction",
llvm::cl::desc("Enables optimization assumption that functions rarely throw errors."));

llvm::cl::opt<bool>
EnableMoveInoutStackProtection = llvm::cl::opt<bool>("enable-move-inout-stack-protector",
llvm::cl::desc("Enable the stack protector by moving values to temporaries."));
Expand Down Expand Up @@ -847,6 +851,7 @@ int sil_opt_main(ArrayRef<const char *> argv, void *MainAddr) {

SILOpts.EnableSpeculativeDevirtualization = options.EnableSpeculativeDevirtualization;
SILOpts.EnableAsyncDemotion = options.EnableAsyncDemotion;
SILOpts.EnableThrowsPrediction = options.EnableThrowsPrediction;
SILOpts.IgnoreAlwaysInline = options.IgnoreAlwaysInline;
SILOpts.EnableOSSAModules = options.EnableOSSAModules;
SILOpts.EnableSILOpaqueValues = options.EnableSILOpaqueValues;
Expand Down
7 changes: 7 additions & 0 deletions lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2597,6 +2597,9 @@ static bool ParseSILArgs(SILOptions &Opts, ArgList &Args,
OPT_enable_sil_opaque_values, OPT_disable_sil_opaque_values, false);
Opts.EnableSpeculativeDevirtualization |= Args.hasArg(OPT_enable_spec_devirt);
Opts.EnableAsyncDemotion |= Args.hasArg(OPT_enable_async_demotion);
Opts.EnableThrowsPrediction = Args.hasFlag(
OPT_enable_throws_prediction, OPT_disable_throws_prediction,
Opts.EnableThrowsPrediction);
Opts.EnableActorDataRaceChecks |= Args.hasFlag(
OPT_enable_actor_data_race_checks,
OPT_disable_actor_data_race_checks, /*default=*/false);
Expand Down Expand Up @@ -3359,6 +3362,10 @@ static bool ParseIRGenArgs(IRGenOptions &Opts, ArgList &Args,
Args.hasFlag(OPT_enable_fragile_resilient_protocol_witnesses,
OPT_disable_fragile_resilient_protocol_witnesses,
Opts.UseFragileResilientProtocolWitnesses);
Opts.EnableHotColdSplit =
Args.hasFlag(OPT_enable_split_cold_code,
OPT_disable_split_cold_code,
Opts.EnableHotColdSplit);
Opts.EnableLargeLoadableTypesReg2Mem =
Args.hasFlag(OPT_enable_large_loadable_types_reg2mem,
OPT_disable_large_loadable_types_reg2mem,
Expand Down
6 changes: 6 additions & 0 deletions lib/IRGen/GenThunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,12 @@ void IRGenThunk::emit() {
llvm::Value *nil = llvm::ConstantPointerNull::get(
cast<llvm::PointerType>(errorValue->getType()));
auto *hasError = IGF.Builder.CreateICmpNE(errorValue, nil);

// Predict no error is thrown.
hasError =
IGF.IGM.getSILModule().getOptions().EnableThrowsPrediction ?
IGF.Builder.CreateExpectCond(IGF.IGM, hasError, false) : hasError;

IGF.Builder.CreateCondBr(hasError, errorBB, successBB);

IGF.Builder.emitBlock(errorBB);
Expand Down
5 changes: 5 additions & 0 deletions lib/IRGen/IRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,11 @@ class IRBuilder : public IRBuilderBase {
name);
}

/// Creates an @llvm.expect.i1 call, where the value should be an i1 type.
llvm::CallInst *CreateExpectCond(IRGenModule &IGM,
llvm::Value *value,
bool expectedValue, const Twine &name = "");

/// Call the trap intrinsic. If optimizations are enabled, an inline asm
/// gadget is emitted before the trap. The gadget inhibits transforms which
/// merge trap calls together, which makes debugging crashes easier.
Expand Down
4 changes: 4 additions & 0 deletions lib/IRGen/IRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ void swift::performLLVMOptimizations(const IRGenOptions &Opts,
bool RunSwiftSpecificLLVMOptzns =
!Opts.DisableSwiftSpecificLLVMOptzns && !Opts.DisableLLVMOptzns;

bool DoHotColdSplit = false;
PTO.CallGraphProfile = false;

llvm::OptimizationLevel level = llvm::OptimizationLevel::O0;
Expand All @@ -221,6 +222,7 @@ void swift::performLLVMOptimizations(const IRGenOptions &Opts,
PTO.LoopVectorization = true;
PTO.SLPVectorization = true;
PTO.MergeFunctions = true;
DoHotColdSplit = Opts.EnableHotColdSplit;
level = llvm::OptimizationLevel::Os;
} else {
level = llvm::OptimizationLevel::O0;
Expand Down Expand Up @@ -259,6 +261,8 @@ void swift::performLLVMOptimizations(const IRGenOptions &Opts,
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
ModulePassManager MPM;

PB.setEnableHotColdSplitting(DoHotColdSplit);

if (RunSwiftSpecificLLVMOptzns) {
PB.registerScalarOptimizerLateEPCallback(
[](FunctionPassManager &FPM, OptimizationLevel Level) {
Expand Down
14 changes: 14 additions & 0 deletions lib/IRGen/IRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,14 @@ Address IRGenFunction::emitAddressAtOffset(llvm::Value *base, Offset offset,
return Address(slotPtr, objectTy, objectAlignment);
}

/// Emit an `llvm.expect` call stating that the i1 \p value is expected to
/// equal \p expectedValue, and return the resulting call instruction.
llvm::CallInst *IRBuilder::CreateExpectCond(IRGenModule &IGM,
                                            llvm::Value *value,
                                            bool expectedValue,
                                            const Twine &name) {
  // Materialize the expected outcome as an i1 constant (1 for true, 0 for
  // false) so it can be handed to CreateExpect next to the tested value.
  auto *expected =
      llvm::ConstantInt::get(IGM.Int1Ty, expectedValue ? 1u : 0u);
  return CreateExpect(value, expected, name);
}

llvm::CallInst *IRBuilder::CreateNonMergeableTrap(IRGenModule &IGM,
StringRef failureMsg) {
if (IGM.DebugInfo && IGM.getOptions().isDebugInfoCodeView()) {
Expand Down Expand Up @@ -777,6 +785,12 @@ void IRGenFunction::emitAwaitAsyncContinuation(
auto nullError = llvm::Constant::getNullValue(errorRes->getType());
auto hasError = Builder.CreateICmpNE(errorRes, nullError);
optionalErrorResult->addIncoming(errorRes, Builder.GetInsertBlock());

// Predict no error.
hasError =
getSILModule().getOptions().EnableThrowsPrediction ?
Builder.CreateExpectCond(IGM, hasError, false) : hasError;

Builder.CreateCondBr(hasError, optionalErrorBB, normalContBB);
Builder.emitBlock(normalContBB);
}
Expand Down
17 changes: 16 additions & 1 deletion lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3884,9 +3884,18 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
// FIXME: Remove this when the following radar is fixed: rdar://116636601
Builder.CreatePtrToInt(errorValue, IGM.IntPtrTy);

// Emit profile metadata if available.
llvm::MDNode *Weights = nullptr;
auto NormalBBCount = tryApplyInst->getNormalBBCount();
auto ErrorBBCount = tryApplyInst->getErrorBBCount();
if (NormalBBCount || ErrorBBCount)
Weights = IGM.createProfileWeights(ErrorBBCount ? ErrorBBCount.getValue() : 0,
NormalBBCount ? NormalBBCount.getValue() : 0);

Builder.CreateCondBr(hasError,
typedErrorLoadBB ? typedErrorLoadBB : errorDest.bb,
normalDest.bb);
normalDest.bb,
Weights);

// Set up the PHI nodes on the normal edge.
unsigned firstIndex = 0;
Expand Down Expand Up @@ -4627,6 +4636,12 @@ void IRGenSILFunction::visitYieldInst(swift::YieldInst *i) {
// Branch to the appropriate destination.
auto unwindBB = getLoweredBB(i->getUnwindBB()).bb;
auto resumeBB = getLoweredBB(i->getResumeBB()).bb;

// Predict no unwind happens.
isUnwind =
IGM.getSILModule().getOptions().EnableThrowsPrediction ?
Builder.CreateExpectCond(IGM, isUnwind, false) : isUnwind;

Builder.CreateCondBr(isUnwind, unwindBB, resumeBB);
}

Expand Down
31 changes: 25 additions & 6 deletions lib/SIL/IR/SILInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,19 +783,24 @@ PartialApplyInst *PartialApplyInst::create(
TryApplyInstBase::TryApplyInstBase(SILInstructionKind kind,
SILDebugLocation loc,
SILBasicBlock *normalBB,
SILBasicBlock *errorBB)
: TermInst(kind, loc), DestBBs{{{this, normalBB}, {this, errorBB}}} {}
SILBasicBlock *errorBB,
ProfileCounter normalCount,
ProfileCounter errorCount)
: TermInst(kind, loc), DestBBs{{{this, normalBB, normalCount},
{this, errorBB, errorCount}}} {}

TryApplyInst::TryApplyInst(
SILDebugLocation loc, SILValue callee, SILType substCalleeTy,
SubstitutionMap subs, ArrayRef<SILValue> args,
ArrayRef<SILValue> typeDependentOperands, SILBasicBlock *normalBB,
SILBasicBlock *errorBB, ApplyOptions options,
const GenericSpecializationInformation *specializationInfo,
std::optional<ApplyIsolationCrossing> isolationCrossing)
std::optional<ApplyIsolationCrossing> isolationCrossing,
ProfileCounter normalCount,
ProfileCounter errorCount)
: InstructionBase(isolationCrossing, loc, callee, substCalleeTy, subs, args,
typeDependentOperands, specializationInfo, normalBB,
errorBB) {
errorBB, normalCount, errorCount) {
setApplyOptions(options);
}

Expand All @@ -805,19 +810,33 @@ TryApplyInst::create(SILDebugLocation loc, SILValue callee,
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
ApplyOptions options, SILFunction &parentFunction,
const GenericSpecializationInformation *specializationInfo,
std::optional<ApplyIsolationCrossing> isolationCrossing) {
std::optional<ApplyIsolationCrossing> isolationCrossing,
ProfileCounter normalCount,
ProfileCounter errorCount) {
SILType substCalleeTy = callee->getType().substGenericArgs(
parentFunction.getModule(), subs,
parentFunction.getTypeExpansionContext());

if (parentFunction.getModule().getOptions().EnableThrowsPrediction &&
!normalCount && !errorCount) {
// Predict that the error branch is not taken.
//
// We cannot use the Expect builtin within SIL because try_apply abstracts
// over the raw conditional test to see if an error was returned.
// So, we synthesize profiling branch weights instead.
normalCount = 1999;
errorCount = 0;
}

SmallVector<SILValue, 32> typeDependentOperands;
collectTypeDependentOperands(typeDependentOperands, parentFunction,
substCalleeTy.getASTType(), subs);
void *buffer = allocateTrailingInst<TryApplyInst, Operand>(
parentFunction, getNumAllOperands(args, typeDependentOperands));
return ::new (buffer) TryApplyInst(
loc, callee, substCalleeTy, subs, args, typeDependentOperands, normalBB,
errorBB, options, specializationInfo, isolationCrossing);
errorBB, options, specializationInfo, isolationCrossing,
normalCount, errorCount);
}

SILType DifferentiableFunctionInst::getDifferentiableFunctionType(
Expand Down
4 changes: 4 additions & 0 deletions lib/SIL/IR/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
visitApplyInstBase(AI);
*this << ", normal " << Ctx.getID(AI->getNormalBB());
*this << ", error " << Ctx.getID(AI->getErrorBB());
if (AI->getNormalBBCount())
*this << " !normal_count(" << AI->getNormalBBCount().getValue() << ")";
if (AI->getErrorBBCount())
*this << " !error_count(" << AI->getErrorBBCount().getValue() << ")";
}

void visitPartialApplyInst(PartialApplyInst *CI) {
Expand Down
37 changes: 37 additions & 0 deletions test/IRGen/cold_split.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: %target-swift-frontend %s -module-name=test -emit-assembly \
// RUN: -enable-throws-prediction -O -enable-split-cold-code \
// RUN: | %FileCheck --check-prefix CHECK-ENABLED %s

//// Test disabling just the pass doesn't yield a split.
// RUN: %target-swift-frontend %s -module-name=test -emit-assembly \
// RUN: -enable-throws-prediction -O -disable-split-cold-code \
// RUN: | %FileCheck --check-prefix CHECK-DISABLED %s

//// Test disabling optimization entirely doesn't yield a split.
// RUN: %target-swift-frontend %s -module-name=test -emit-assembly \
// RUN: -enable-throws-prediction -enable-split-cold-code \
// RUN: | %FileCheck --check-prefix CHECK-DISABLED %s


// CHECK-ENABLED: cold

// CHECK-DISABLED-NOT: cold

/// Minimal error type with a single case, thrown by `getRandom(_:)` below.
enum MyError: Error { case err }

/// Returns a random integer in `0..<1024` when `b` is true.
///
/// - Parameter b: Selects between the returning path and the throwing path.
/// - Returns: A random `Int` in the half-open range `0..<1024`.
/// - Throws: `MyError.err` when `b` is false.
func getRandom(_ b: Bool) throws -> Int {
if b {
return Int.random(in: 0..<1024)
} else {
throw MyError.err
}
}

/// Calls `getRandom(_:)` and falls back to a fixed value if it throws.
///
/// - Parameter b: Forwarded to `getRandom(_:)`.
/// - Returns: The value from `getRandom(_:)`, or `1337` after printing a log
///   line when an error is caught.
public func numberWithLogging(_ b: Bool) -> Int {
do {
return try getRandom(b)
} catch {
// NOTE(review): presumably this error-handling path is the code the RUN
// lines above expect to be outlined when both throws prediction and
// cold-code splitting are enabled under -O; confirm against the emitted
// assembly.
print("Log: random number generator failed with b=\(b)")
return 1337
}
}
Loading