Skip to content

Commit 44e21b0

Browse files
committed
BranchPredict: try_apply favors the normal BB
1 parent fe4c85d commit 44e21b0

File tree

6 files changed

+93
-12
lines changed

6 files changed

+93
-12
lines changed

include/swift/SIL/SILBuilder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,10 +555,13 @@ class SILBuilder {
555555
ArrayRef<SILValue> args, SILBasicBlock *normalBB, SILBasicBlock *errorBB,
556556
ApplyOptions options = ApplyOptions(),
557557
const GenericSpecializationInformation *specializationInfo = nullptr,
558-
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt) {
558+
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt,
559+
ProfileCounter normalCount = ProfileCounter(),
560+
ProfileCounter errorCount = ProfileCounter()) {
559561
return insertTerminator(TryApplyInst::create(
560562
getSILDebugLocation(loc), callee, subs, args, normalBB, errorBB,
561-
options, *F, specializationInfo, isolationCrossing));
563+
options, *F, specializationInfo, isolationCrossing,
564+
normalCount, errorCount));
562565
}
563566

564567
PartialApplyInst *createPartialApply(

include/swift/SIL/SILInstruction.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10968,7 +10968,8 @@ class TryApplyInstBase : public TermInst {
1096810968

1096910969
protected:
1097010970
TryApplyInstBase(SILInstructionKind valueKind, SILDebugLocation Loc,
10971-
SILBasicBlock *normalBB, SILBasicBlock *errorBB);
10971+
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
10972+
ProfileCounter normalCount, ProfileCounter errorCount);
1097210973

1097310974
public:
1097410975
SuccessorListTy getSuccessors() {
@@ -10988,6 +10989,11 @@ class TryApplyInstBase : public TermInst {
1098810989
const SILBasicBlock *getNormalBB() const { return DestBBs[NormalIdx]; }
1098910990
SILBasicBlock *getErrorBB() { return DestBBs[ErrorIdx]; }
1099010991
const SILBasicBlock *getErrorBB() const { return DestBBs[ErrorIdx]; }
10992+
10993+
/// The number of times the Normal branch was executed
10994+
ProfileCounter getNormalBBCount() const { return DestBBs[NormalIdx].getCount(); }
10995+
/// The number of times the Error branch was executed
10996+
ProfileCounter getErrorBBCount() const { return DestBBs[ErrorIdx].getCount(); }
1099110997
};
1099210998

1099310999
/// TryApplyInst - Represents the full application of a function that
@@ -11005,15 +11011,19 @@ class TryApplyInst final
1100511011
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
1100611012
ApplyOptions options,
1100711013
const GenericSpecializationInformation *specializationInfo,
11008-
std::optional<ApplyIsolationCrossing> isolationCrossing);
11014+
std::optional<ApplyIsolationCrossing> isolationCrossing,
11015+
ProfileCounter normalCount,
11016+
ProfileCounter errorCount);
1100911017

1101011018
static TryApplyInst *
1101111019
create(SILDebugLocation debugLoc, SILValue callee,
1101211020
SubstitutionMap substitutions, ArrayRef<SILValue> args,
1101311021
SILBasicBlock *normalBB, SILBasicBlock *errorBB, ApplyOptions options,
1101411022
SILFunction &parentFunction,
1101511023
const GenericSpecializationInformation *specializationInfo,
11016-
std::optional<ApplyIsolationCrossing> isolationCrossing);
11024+
std::optional<ApplyIsolationCrossing> isolationCrossing,
11025+
ProfileCounter normalCount,
11026+
ProfileCounter errorCount);
1101711027
};
1101811028

1101911029
/// DifferentiableFunctionInst - creates a `@differentiable` function-typed

lib/IRGen/IRGenSIL.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3884,9 +3884,18 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
38843884
// FIXME: Remove this when the following radar is fixed: rdar://116636601
38853885
Builder.CreatePtrToInt(errorValue, IGM.IntPtrTy);
38863886

3887+
// Emit profile metadata if available.
3888+
llvm::MDNode *Weights = nullptr;
3889+
auto NormalBBCount = tryApplyInst->getNormalBBCount();
3890+
auto ErrorBBCount = tryApplyInst->getErrorBBCount();
3891+
if (NormalBBCount || ErrorBBCount)
3892+
Weights = IGM.createProfileWeights(ErrorBBCount ? ErrorBBCount.getValue() : 0,
3893+
NormalBBCount ? NormalBBCount.getValue() : 0);
3894+
38873895
Builder.CreateCondBr(hasError,
38883896
typedErrorLoadBB ? typedErrorLoadBB : errorDest.bb,
3889-
normalDest.bb);
3897+
normalDest.bb,
3898+
Weights);
38903899

38913900
// Set up the PHI nodes on the normal edge.
38923901
unsigned firstIndex = 0;

lib/SIL/IR/SILInstructions.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -783,19 +783,24 @@ PartialApplyInst *PartialApplyInst::create(
783783
TryApplyInstBase::TryApplyInstBase(SILInstructionKind kind,
784784
SILDebugLocation loc,
785785
SILBasicBlock *normalBB,
786-
SILBasicBlock *errorBB)
787-
: TermInst(kind, loc), DestBBs{{{this, normalBB}, {this, errorBB}}} {}
786+
SILBasicBlock *errorBB,
787+
ProfileCounter normalCount,
788+
ProfileCounter errorCount)
789+
: TermInst(kind, loc), DestBBs{{{this, normalBB, normalCount},
790+
{this, errorBB, errorCount}}} {}
788791

789792
TryApplyInst::TryApplyInst(
790793
SILDebugLocation loc, SILValue callee, SILType substCalleeTy,
791794
SubstitutionMap subs, ArrayRef<SILValue> args,
792795
ArrayRef<SILValue> typeDependentOperands, SILBasicBlock *normalBB,
793796
SILBasicBlock *errorBB, ApplyOptions options,
794797
const GenericSpecializationInformation *specializationInfo,
795-
std::optional<ApplyIsolationCrossing> isolationCrossing)
798+
std::optional<ApplyIsolationCrossing> isolationCrossing,
799+
ProfileCounter normalCount,
800+
ProfileCounter errorCount)
796801
: InstructionBase(isolationCrossing, loc, callee, substCalleeTy, subs, args,
797802
typeDependentOperands, specializationInfo, normalBB,
798-
errorBB) {
803+
errorBB, normalCount, errorCount) {
799804
setApplyOptions(options);
800805
}
801806

@@ -805,19 +810,30 @@ TryApplyInst::create(SILDebugLocation loc, SILValue callee,
805810
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
806811
ApplyOptions options, SILFunction &parentFunction,
807812
const GenericSpecializationInformation *specializationInfo,
808-
std::optional<ApplyIsolationCrossing> isolationCrossing) {
813+
std::optional<ApplyIsolationCrossing> isolationCrossing,
814+
ProfileCounter normalCount,
815+
ProfileCounter errorCount) {
809816
SILType substCalleeTy = callee->getType().substGenericArgs(
810817
parentFunction.getModule(), subs,
811818
parentFunction.getTypeExpansionContext());
812819

820+
if (parentFunction.getModule().getOptions().EnableStaticBranchPrediction) {
821+
if (!normalCount && !errorCount) {
822+
// Predict that the error branch is not taken.
823+
normalCount = 999;
824+
errorCount = 0;
825+
}
826+
}
827+
813828
SmallVector<SILValue, 32> typeDependentOperands;
814829
collectTypeDependentOperands(typeDependentOperands, parentFunction,
815830
substCalleeTy.getASTType(), subs);
816831
void *buffer = allocateTrailingInst<TryApplyInst, Operand>(
817832
parentFunction, getNumAllOperands(args, typeDependentOperands));
818833
return ::new (buffer) TryApplyInst(
819834
loc, callee, substCalleeTy, subs, args, typeDependentOperands, normalBB,
820-
errorBB, options, specializationInfo, isolationCrossing);
835+
errorBB, options, specializationInfo, isolationCrossing,
836+
normalCount, errorCount);
821837
}
822838

823839
SILType DifferentiableFunctionInst::getDifferentiableFunctionType(

lib/SIL/IR/SILPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
15731573
visitApplyInstBase(AI);
15741574
*this << ", normal " << Ctx.getID(AI->getNormalBB());
15751575
*this << ", error " << Ctx.getID(AI->getErrorBB());
1576+
if (AI->getNormalBBCount())
1577+
*this << " !normal_count(" << AI->getNormalBBCount().getValue() << ")";
1578+
if (AI->getErrorBBCount())
1579+
*this << " !error_count(" << AI->getErrorBBCount().getValue() << ")";
15761580
}
15771581

15781582
void visitPartialApplyInst(PartialApplyInst *CI) {
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %target-swift-frontend %s \
2+
// RUN: -enable-experimental-static-branch-prediction \
3+
// RUN: -sil-verify-all -module-name=test -emit-sil \
4+
// RUN: | %FileCheck --check-prefix CHECK-SIL %s
5+
6+
// RUN: %target-swift-frontend %s \
7+
// RUN: -enable-experimental-static-branch-prediction \
8+
// RUN: -sil-verify-all -module-name=test -emit-irgen \
9+
// RUN: | %FileCheck --check-prefix CHECK-IR %s
10+
11+
enum MyError: Error { case err }
12+
13+
func throwy1() throws {}
14+
func throwy2() throws(MyError) { }
15+
16+
// CHECK-SIL-LABEL: sil hidden @$s4test0A13TryPredictionyySbF
17+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(999) !error_count(0)
18+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(999) !error_count(0)
19+
20+
// CHECK-IR-LABEL: define hidden swiftcc void @"$s4test0A13TryPredictionyySbF"
21+
// CHECK-IR: call swiftcc void @"$s4test7throwy1yyKF"
22+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
23+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
24+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
25+
26+
// CHECK-IR: call swiftcc void @"$s4test7throwy2yyAA7MyErrorOYKF"
27+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
28+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
29+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
30+
func testTryPrediction(_ b: Bool) {
31+
do {
32+
try throwy1()
33+
try throwy2()
34+
} catch {
35+
print("hi")
36+
}
37+
}
38+
39+
// CHECK-IR: [[PREFER_FALSE]] = !{!"branch_weights", i32 1, i32 1000}

0 commit comments

Comments
 (0)