Skip to content

Commit 0cce602

Browse files
committed
SIL: branch weights for try_apply's
1 parent d373f6f commit 0cce602

File tree

6 files changed

+179
-12
lines changed

6 files changed

+179
-12
lines changed

include/swift/SIL/SILBuilder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,10 +555,13 @@ class SILBuilder {
555555
ArrayRef<SILValue> args, SILBasicBlock *normalBB, SILBasicBlock *errorBB,
556556
ApplyOptions options = ApplyOptions(),
557557
const GenericSpecializationInformation *specializationInfo = nullptr,
558-
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt) {
558+
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt,
559+
ProfileCounter normalCount = ProfileCounter(),
560+
ProfileCounter errorCount = ProfileCounter()) {
559561
return insertTerminator(TryApplyInst::create(
560562
getSILDebugLocation(loc), callee, subs, args, normalBB, errorBB,
561-
options, *F, specializationInfo, isolationCrossing));
563+
options, *F, specializationInfo, isolationCrossing,
564+
normalCount, errorCount));
562565
}
563566

564567
PartialApplyInst *createPartialApply(

include/swift/SIL/SILInstruction.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10968,7 +10968,8 @@ class TryApplyInstBase : public TermInst {
1096810968

1096910969
protected:
1097010970
TryApplyInstBase(SILInstructionKind valueKind, SILDebugLocation Loc,
10971-
SILBasicBlock *normalBB, SILBasicBlock *errorBB);
10971+
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
10972+
ProfileCounter normalCount, ProfileCounter errorCount);
1097210973

1097310974
public:
1097410975
SuccessorListTy getSuccessors() {
@@ -10988,6 +10989,11 @@ class TryApplyInstBase : public TermInst {
1098810989
const SILBasicBlock *getNormalBB() const { return DestBBs[NormalIdx]; }
1098910990
SILBasicBlock *getErrorBB() { return DestBBs[ErrorIdx]; }
1099010991
const SILBasicBlock *getErrorBB() const { return DestBBs[ErrorIdx]; }
10992+
10993+
/// The number of times the Normal branch was executed
10994+
ProfileCounter getNormalBBCount() const { return DestBBs[NormalIdx].getCount(); }
10995+
/// The number of times the Error branch was executed
10996+
ProfileCounter getErrorBBCount() const { return DestBBs[ErrorIdx].getCount(); }
1099110997
};
1099210998

1099310999
/// TryApplyInst - Represents the full application of a function that
@@ -11005,15 +11011,19 @@ class TryApplyInst final
1100511011
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
1100611012
ApplyOptions options,
1100711013
const GenericSpecializationInformation *specializationInfo,
11008-
std::optional<ApplyIsolationCrossing> isolationCrossing);
11014+
std::optional<ApplyIsolationCrossing> isolationCrossing,
11015+
ProfileCounter normalCount,
11016+
ProfileCounter errorCount);
1100911017

1101011018
static TryApplyInst *
1101111019
create(SILDebugLocation debugLoc, SILValue callee,
1101211020
SubstitutionMap substitutions, ArrayRef<SILValue> args,
1101311021
SILBasicBlock *normalBB, SILBasicBlock *errorBB, ApplyOptions options,
1101411022
SILFunction &parentFunction,
1101511023
const GenericSpecializationInformation *specializationInfo,
11016-
std::optional<ApplyIsolationCrossing> isolationCrossing);
11024+
std::optional<ApplyIsolationCrossing> isolationCrossing,
11025+
ProfileCounter normalCount,
11026+
ProfileCounter errorCount);
1101711027
};
1101811028

1101911029
/// DifferentiableFunctionInst - creates a `@differentiable` function-typed

lib/IRGen/IRGenSIL.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3884,9 +3884,18 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
38843884
// FIXME: Remove this when the following radar is fixed: rdar://116636601
38853885
Builder.CreatePtrToInt(errorValue, IGM.IntPtrTy);
38863886

3887+
// Emit profile metadata if available.
3888+
llvm::MDNode *Weights = nullptr;
3889+
auto NormalBBCount = tryApplyInst->getNormalBBCount();
3890+
auto ErrorBBCount = tryApplyInst->getErrorBBCount();
3891+
if (NormalBBCount || ErrorBBCount)
3892+
Weights = IGM.createProfileWeights(ErrorBBCount ? ErrorBBCount.getValue() : 0,
3893+
NormalBBCount ? NormalBBCount.getValue() : 0);
3894+
38873895
Builder.CreateCondBr(hasError,
38883896
typedErrorLoadBB ? typedErrorLoadBB : errorDest.bb,
3889-
normalDest.bb);
3897+
normalDest.bb,
3898+
Weights);
38903899

38913900
// Set up the PHI nodes on the normal edge.
38923901
unsigned firstIndex = 0;

lib/SIL/IR/SILInstructions.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -783,19 +783,24 @@ PartialApplyInst *PartialApplyInst::create(
783783
TryApplyInstBase::TryApplyInstBase(SILInstructionKind kind,
784784
SILDebugLocation loc,
785785
SILBasicBlock *normalBB,
786-
SILBasicBlock *errorBB)
787-
: TermInst(kind, loc), DestBBs{{{this, normalBB}, {this, errorBB}}} {}
786+
SILBasicBlock *errorBB,
787+
ProfileCounter normalCount,
788+
ProfileCounter errorCount)
789+
: TermInst(kind, loc), DestBBs{{{this, normalBB, normalCount},
790+
{this, errorBB, errorCount}}} {}
788791

789792
TryApplyInst::TryApplyInst(
790793
SILDebugLocation loc, SILValue callee, SILType substCalleeTy,
791794
SubstitutionMap subs, ArrayRef<SILValue> args,
792795
ArrayRef<SILValue> typeDependentOperands, SILBasicBlock *normalBB,
793796
SILBasicBlock *errorBB, ApplyOptions options,
794797
const GenericSpecializationInformation *specializationInfo,
795-
std::optional<ApplyIsolationCrossing> isolationCrossing)
798+
std::optional<ApplyIsolationCrossing> isolationCrossing,
799+
ProfileCounter normalCount,
800+
ProfileCounter errorCount)
796801
: InstructionBase(isolationCrossing, loc, callee, substCalleeTy, subs, args,
797802
typeDependentOperands, specializationInfo, normalBB,
798-
errorBB) {
803+
errorBB, normalCount, errorCount) {
799804
setApplyOptions(options);
800805
}
801806

@@ -805,19 +810,33 @@ TryApplyInst::create(SILDebugLocation loc, SILValue callee,
805810
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
806811
ApplyOptions options, SILFunction &parentFunction,
807812
const GenericSpecializationInformation *specializationInfo,
808-
std::optional<ApplyIsolationCrossing> isolationCrossing) {
813+
std::optional<ApplyIsolationCrossing> isolationCrossing,
814+
ProfileCounter normalCount,
815+
ProfileCounter errorCount) {
809816
SILType substCalleeTy = callee->getType().substGenericArgs(
810817
parentFunction.getModule(), subs,
811818
parentFunction.getTypeExpansionContext());
812819

820+
if (parentFunction.getModule().getOptions().EnableThrowsPrediction &&
821+
!normalCount && !errorCount) {
822+
// Predict that the error branch is not taken.
823+
//
824+
// We cannot use the Expect builtin within SIL because try_apply abstracts
825+
// over the raw conditional test to see if an error was returned.
826+
// So, we synthesize profiling branch weights instead.
827+
normalCount = 1999;
828+
errorCount = 0;
829+
}
830+
813831
SmallVector<SILValue, 32> typeDependentOperands;
814832
collectTypeDependentOperands(typeDependentOperands, parentFunction,
815833
substCalleeTy.getASTType(), subs);
816834
void *buffer = allocateTrailingInst<TryApplyInst, Operand>(
817835
parentFunction, getNumAllOperands(args, typeDependentOperands));
818836
return ::new (buffer) TryApplyInst(
819837
loc, callee, substCalleeTy, subs, args, typeDependentOperands, normalBB,
820-
errorBB, options, specializationInfo, isolationCrossing);
838+
errorBB, options, specializationInfo, isolationCrossing,
839+
normalCount, errorCount);
821840
}
822841

823842
SILType DifferentiableFunctionInst::getDifferentiableFunctionType(

lib/SIL/IR/SILPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
15731573
visitApplyInstBase(AI);
15741574
*this << ", normal " << Ctx.getID(AI->getNormalBB());
15751575
*this << ", error " << Ctx.getID(AI->getErrorBB());
1576+
if (AI->getNormalBBCount())
1577+
*this << " !normal_count(" << AI->getNormalBBCount().getValue() << ")";
1578+
if (AI->getErrorBBCount())
1579+
*this << " !error_count(" << AI->getErrorBBCount().getValue() << ")";
15761580
}
15771581

15781582
void visitPartialApplyInst(PartialApplyInst *CI) {
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// RUN: %target-swift-frontend %s \
2+
// RUN: -disable-availability-checking \
3+
// RUN: -enable-throws-prediction \
4+
// RUN: -sil-verify-all -module-name=test -emit-sil \
5+
// RUN: | %FileCheck --check-prefix CHECK-SIL %s
6+
7+
// RUN: %target-swift-frontend %s \
8+
// RUN: -disable-availability-checking \
9+
// RUN: -enable-throws-prediction \
10+
// RUN: -sil-verify-all -module-name=test -emit-irgen \
11+
// RUN: | %FileCheck --check-prefix CHECK-IR %s
12+
13+
// RUN: %target-swift-frontend %s \
14+
// RUN: -disable-availability-checking \
15+
// RUN: -disable-throws-prediction \
16+
// RUN: -sil-verify-all -module-name=test -emit-sil \
17+
// RUN: | %FileCheck --check-prefix CHECK-DISABLED %s
18+
19+
// CHECK-DISABLED-NOT: normal_count
20+
21+
enum MyError: Error { case err }
22+
23+
func throwy1() throws {}
24+
func throwy2() throws(MyError) { }
25+
func throwy3() async throws -> Int { 0 }
26+
func throwy4() async throws(MyError) -> Int { 1 }
27+
28+
// CHECK-SIL-LABEL: sil hidden @$s4test0A13TryPredictionyySbF
29+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
30+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(1999) !error_count(0)
31+
32+
// CHECK-IR-LABEL: define hidden swiftcc void @"$s4test0A13TryPredictionyySbF"
33+
// CHECK-IR: call swiftcc void @"$s4test7throwy1yyKF"
34+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
35+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
36+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
37+
38+
// CHECK-IR: call swiftcc void @"$s4test7throwy2yyAA7MyErrorOYKF"
39+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
40+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
41+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
42+
func testTryPrediction(_ b: Bool) {
43+
do {
44+
try throwy1()
45+
try throwy2()
46+
} catch {
47+
print("hi")
48+
}
49+
}
50+
51+
// CHECK-SIL-LABEL: sil hidden @$s4test0A21AsyncThrowsPredictionSiyYaF
52+
// CHECK-SIL: function_ref @$s4test7throwy3SiyYaKF
53+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
54+
55+
// CHECK-IR-LABEL: define hidden swifttailcc void @"$s4test0A21AsyncThrowsPredictionSiyYaF"
56+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
57+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
58+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
59+
func testAsyncThrowsPrediction() async -> Int {
60+
if let x = try? await throwy3() {
61+
return x
62+
}
63+
return 1337
64+
}
65+
66+
// CHECK-SIL-LABEL: sil hidden @$s4test0A28Async_TYPED_ThrowsPredictionSiyYaAA7MyErrorOYKF
67+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(1999) !error_count(0)
68+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(1999) !error_count(0)
69+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(1999) !error_count(0)
70+
71+
// CHECK-IR-LABEL: define hidden swifttailcc void @"$s4test0A28Async_TYPED_ThrowsPredictionSiyYaAA7MyErrorOYKF"
72+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
73+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
74+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
75+
//
76+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
77+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
78+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
79+
//
80+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
81+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
82+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
83+
func testAsync_TYPED_ThrowsPrediction() async throws(MyError) -> Int {
84+
let x = try await throwy4()
85+
let y = try await throwy4()
86+
let z = try await throwy4()
87+
return x + y + z
88+
}
89+
90+
91+
func getRandom(_ b: Bool) throws -> Int {
92+
if b {
93+
return Int.random(in: 0..<1024)
94+
} else {
95+
throw MyError.err
96+
}
97+
}
98+
99+
// CHECK-SIL-LABEL: sil hidden @$s4test20sequenceOfNormalTrysySiSb_S2btKF
100+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
101+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
102+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
103+
104+
// CHECK-IR-LABEL: define hidden swiftcc i64 @"$s4test20sequenceOfNormalTrysySiSb_S2btKF"
105+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr {{%.*}}, null
106+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
107+
//
108+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr {{%.*}}, null
109+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
110+
//
111+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr {{%.*}}, null
112+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
113+
func sequenceOfNormalTrys(_ b1: Bool,
114+
_ b2: Bool,
115+
_ b3: Bool) throws -> Int {
116+
let x = try getRandom(b1)
117+
let y = try getRandom(b2)
118+
let z = try getRandom(b3)
119+
return x + y + z
120+
}
121+
122+
// CHECK-IR: [[PREFER_FALSE]] = !{!"branch_weights", i32 1, i32 2000}

0 commit comments

Comments
 (0)