Skip to content

Commit a371c20

Browse files
committed
SIL: branch weights for try_apply's
1 parent 12c3da6 commit a371c20

File tree

6 files changed

+169
-12
lines changed

6 files changed

+169
-12
lines changed

include/swift/SIL/SILBuilder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,10 +555,13 @@ class SILBuilder {
555555
ArrayRef<SILValue> args, SILBasicBlock *normalBB, SILBasicBlock *errorBB,
556556
ApplyOptions options = ApplyOptions(),
557557
const GenericSpecializationInformation *specializationInfo = nullptr,
558-
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt) {
558+
std::optional<ApplyIsolationCrossing> isolationCrossing = std::nullopt,
559+
ProfileCounter normalCount = ProfileCounter(),
560+
ProfileCounter errorCount = ProfileCounter()) {
559561
return insertTerminator(TryApplyInst::create(
560562
getSILDebugLocation(loc), callee, subs, args, normalBB, errorBB,
561-
options, *F, specializationInfo, isolationCrossing));
563+
options, *F, specializationInfo, isolationCrossing,
564+
normalCount, errorCount));
562565
}
563566

564567
PartialApplyInst *createPartialApply(

include/swift/SIL/SILInstruction.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10968,7 +10968,8 @@ class TryApplyInstBase : public TermInst {
1096810968

1096910969
protected:
1097010970
TryApplyInstBase(SILInstructionKind valueKind, SILDebugLocation Loc,
10971-
SILBasicBlock *normalBB, SILBasicBlock *errorBB);
10971+
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
10972+
ProfileCounter normalCount, ProfileCounter errorCount);
1097210973

1097310974
public:
1097410975
SuccessorListTy getSuccessors() {
@@ -10988,6 +10989,11 @@ class TryApplyInstBase : public TermInst {
1098810989
const SILBasicBlock *getNormalBB() const { return DestBBs[NormalIdx]; }
1098910990
SILBasicBlock *getErrorBB() { return DestBBs[ErrorIdx]; }
1099010991
const SILBasicBlock *getErrorBB() const { return DestBBs[ErrorIdx]; }
10992+
10993+
/// The number of times the Normal branch was executed
10994+
ProfileCounter getNormalBBCount() const { return DestBBs[NormalIdx].getCount(); }
10995+
/// The number of times the Error branch was executed
10996+
ProfileCounter getErrorBBCount() const { return DestBBs[ErrorIdx].getCount(); }
1099110997
};
1099210998

1099310999
/// TryApplyInst - Represents the full application of a function that
@@ -11005,15 +11011,19 @@ class TryApplyInst final
1100511011
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
1100611012
ApplyOptions options,
1100711013
const GenericSpecializationInformation *specializationInfo,
11008-
std::optional<ApplyIsolationCrossing> isolationCrossing);
11014+
std::optional<ApplyIsolationCrossing> isolationCrossing,
11015+
ProfileCounter normalCount,
11016+
ProfileCounter errorCount);
1100911017

1101011018
static TryApplyInst *
1101111019
create(SILDebugLocation debugLoc, SILValue callee,
1101211020
SubstitutionMap substitutions, ArrayRef<SILValue> args,
1101311021
SILBasicBlock *normalBB, SILBasicBlock *errorBB, ApplyOptions options,
1101411022
SILFunction &parentFunction,
1101511023
const GenericSpecializationInformation *specializationInfo,
11016-
std::optional<ApplyIsolationCrossing> isolationCrossing);
11024+
std::optional<ApplyIsolationCrossing> isolationCrossing,
11025+
ProfileCounter normalCount,
11026+
ProfileCounter errorCount);
1101711027
};
1101811028

1101911029
/// DifferentiableFunctionInst - creates a `@differentiable` function-typed

lib/IRGen/IRGenSIL.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3884,9 +3884,18 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
38843884
// FIXME: Remove this when the following radar is fixed: rdar://116636601
38853885
Builder.CreatePtrToInt(errorValue, IGM.IntPtrTy);
38863886

3887+
// Emit profile metadata if available.
3888+
llvm::MDNode *Weights = nullptr;
3889+
auto NormalBBCount = tryApplyInst->getNormalBBCount();
3890+
auto ErrorBBCount = tryApplyInst->getErrorBBCount();
3891+
if (NormalBBCount || ErrorBBCount)
3892+
Weights = IGM.createProfileWeights(ErrorBBCount ? ErrorBBCount.getValue() : 0,
3893+
NormalBBCount ? NormalBBCount.getValue() : 0);
3894+
38873895
Builder.CreateCondBr(hasError,
38883896
typedErrorLoadBB ? typedErrorLoadBB : errorDest.bb,
3889-
normalDest.bb);
3897+
normalDest.bb,
3898+
Weights);
38903899

38913900
// Set up the PHI nodes on the normal edge.
38923901
unsigned firstIndex = 0;

lib/SIL/IR/SILInstructions.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -783,19 +783,24 @@ PartialApplyInst *PartialApplyInst::create(
783783
TryApplyInstBase::TryApplyInstBase(SILInstructionKind kind,
784784
SILDebugLocation loc,
785785
SILBasicBlock *normalBB,
786-
SILBasicBlock *errorBB)
787-
: TermInst(kind, loc), DestBBs{{{this, normalBB}, {this, errorBB}}} {}
786+
SILBasicBlock *errorBB,
787+
ProfileCounter normalCount,
788+
ProfileCounter errorCount)
789+
: TermInst(kind, loc), DestBBs{{{this, normalBB, normalCount},
790+
{this, errorBB, errorCount}}} {}
788791

789792
TryApplyInst::TryApplyInst(
790793
SILDebugLocation loc, SILValue callee, SILType substCalleeTy,
791794
SubstitutionMap subs, ArrayRef<SILValue> args,
792795
ArrayRef<SILValue> typeDependentOperands, SILBasicBlock *normalBB,
793796
SILBasicBlock *errorBB, ApplyOptions options,
794797
const GenericSpecializationInformation *specializationInfo,
795-
std::optional<ApplyIsolationCrossing> isolationCrossing)
798+
std::optional<ApplyIsolationCrossing> isolationCrossing,
799+
ProfileCounter normalCount,
800+
ProfileCounter errorCount)
796801
: InstructionBase(isolationCrossing, loc, callee, substCalleeTy, subs, args,
797802
typeDependentOperands, specializationInfo, normalBB,
798-
errorBB) {
803+
errorBB, normalCount, errorCount) {
799804
setApplyOptions(options);
800805
}
801806

@@ -805,19 +810,33 @@ TryApplyInst::create(SILDebugLocation loc, SILValue callee,
805810
SILBasicBlock *normalBB, SILBasicBlock *errorBB,
806811
ApplyOptions options, SILFunction &parentFunction,
807812
const GenericSpecializationInformation *specializationInfo,
808-
std::optional<ApplyIsolationCrossing> isolationCrossing) {
813+
std::optional<ApplyIsolationCrossing> isolationCrossing,
814+
ProfileCounter normalCount,
815+
ProfileCounter errorCount) {
809816
SILType substCalleeTy = callee->getType().substGenericArgs(
810817
parentFunction.getModule(), subs,
811818
parentFunction.getTypeExpansionContext());
812819

820+
if (parentFunction.getModule().getOptions().EnableThrowsPrediction &&
821+
!normalCount && !errorCount) {
822+
// Predict that the error branch is not taken.
823+
//
824+
// We cannot use the Expect builtin within SIL because try_apply abstracts
825+
// over the raw conditional test to see if an error was returned.
826+
// So, we synthesize profiling branch weights instead.
827+
normalCount = 1999;
828+
errorCount = 0;
829+
}
830+
813831
SmallVector<SILValue, 32> typeDependentOperands;
814832
collectTypeDependentOperands(typeDependentOperands, parentFunction,
815833
substCalleeTy.getASTType(), subs);
816834
void *buffer = allocateTrailingInst<TryApplyInst, Operand>(
817835
parentFunction, getNumAllOperands(args, typeDependentOperands));
818836
return ::new (buffer) TryApplyInst(
819837
loc, callee, substCalleeTy, subs, args, typeDependentOperands, normalBB,
820-
errorBB, options, specializationInfo, isolationCrossing);
838+
errorBB, options, specializationInfo, isolationCrossing,
839+
normalCount, errorCount);
821840
}
822841

823842
SILType DifferentiableFunctionInst::getDifferentiableFunctionType(

lib/SIL/IR/SILPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
15731573
visitApplyInstBase(AI);
15741574
*this << ", normal " << Ctx.getID(AI->getNormalBB());
15751575
*this << ", error " << Ctx.getID(AI->getErrorBB());
1576+
if (AI->getNormalBBCount())
1577+
*this << " !normal_count(" << AI->getNormalBBCount().getValue() << ")";
1578+
if (AI->getErrorBBCount())
1579+
*this << " !error_count(" << AI->getErrorBBCount().getValue() << ")";
15761580
}
15771581

15781582
void visitPartialApplyInst(PartialApplyInst *CI) {
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// RUN: %target-swift-frontend %s \
2+
// RUN: -enable-throws-prediction \
3+
// RUN: -sil-verify-all -module-name=test -emit-sil \
4+
// RUN: | %FileCheck --check-prefix CHECK-SIL %s
5+
6+
// RUN: %target-swift-frontend %s \
7+
// RUN: -enable-throws-prediction \
8+
// RUN: -sil-verify-all -module-name=test -emit-irgen \
9+
// RUN: | %FileCheck --check-prefix CHECK-IR %s
10+
11+
enum MyError: Error { case err }
12+
13+
func throwy1() throws {}
14+
func throwy2() throws(MyError) { }
15+
func throwy3() async throws -> Int { 0 }
16+
func throwy4() async throws(MyError) -> Int { 1 }
17+
18+
// CHECK-SIL-LABEL: sil hidden @$s4test0A13TryPredictionyySbF
19+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
20+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(1999) !error_count(0)
21+
22+
// CHECK-IR-LABEL: define hidden swiftcc void @"$s4test0A13TryPredictionyySbF"
23+
// CHECK-IR: call swiftcc void @"$s4test7throwy1yyKF"
24+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
25+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
26+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
27+
28+
// CHECK-IR: call swiftcc void @"$s4test7throwy2yyAA7MyErrorOYKF"
29+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
30+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
31+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
32+
func testTryPrediction(_ b: Bool) {
33+
do {
34+
try throwy1()
35+
try throwy2()
36+
} catch {
37+
print("hi")
38+
}
39+
}
40+
41+
// CHECK-SIL-LABEL: sil hidden @$s4test0A21AsyncThrowsPredictionSiyYaF
42+
// CHECK-SIL: function_ref @$s4test7throwy3SiyYaKF
43+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
44+
45+
// CHECK-IR-LABEL: define hidden swifttailcc void @"$s4test0A21AsyncThrowsPredictionSiyYaF"
46+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
47+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
48+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
49+
func testAsyncThrowsPrediction() async -> Int {
50+
if let x = try? await throwy3() {
51+
return x
52+
}
53+
return 1337
54+
}
55+
56+
// CHECK-SIL-LABEL: sil hidden @$s4test0A28Async_TYPED_ThrowsPredictionSiyYaAA7MyErrorOYKF
57+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(1999) !error_count(0)
58+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(1999) !error_count(0)
59+
// CHECK-SIL: try_apply {{.*}} @error MyError{{.*}} !normal_count(1999) !error_count(0)
60+
61+
// CHECK-IR-LABEL: define hidden swifttailcc void @"$s4test0A28Async_TYPED_ThrowsPredictionSiyYaAA7MyErrorOYKF"
62+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
63+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
64+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
65+
//
66+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
67+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
68+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
69+
//
70+
// CHECK-IR: [[ERROR_PTR:%.*]] = load ptr, ptr %swifterror
71+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr [[ERROR_PTR]], null
72+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
73+
func testAsync_TYPED_ThrowsPrediction() async throws(MyError) -> Int {
74+
let x = try await throwy4()
75+
let y = try await throwy4()
76+
let z = try await throwy4()
77+
return x + y + z
78+
}
79+
80+
81+
func getRandom(_ b: Bool) throws -> Int {
82+
if b {
83+
return Int.random(in: 0..<1024)
84+
} else {
85+
throw MyError.err
86+
}
87+
}
88+
89+
// CHECK-SIL-LABEL: sil @$s4test20sequenceOfNormalTrysySiSb_S2btKF
90+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
91+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
92+
// CHECK-SIL: try_apply {{.*}} @error any Error{{.*}} !normal_count(1999) !error_count(0)
93+
94+
// CHECK-IR-LABEL: define swiftcc i64 @"$s4test20sequenceOfNormalTrysySiSb_S2btKF"
95+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr {{%.*}}, null
96+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
97+
//
98+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr {{%.*}}, null
99+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
100+
//
101+
// CHECK-IR: [[HAVE_ERROR:%.*]] = icmp ne ptr {{%.*}}, null
102+
// CHECK-IR: br i1 [[HAVE_ERROR]], {{.*}} !prof [[PREFER_FALSE:![0-9]+]]
103+
public func sequenceOfNormalTrys(_ b1: Bool,
104+
_ b2: Bool,
105+
_ b3: Bool) throws -> Int {
106+
let x = try getRandom(b1)
107+
let y = try getRandom(b2)
108+
let z = try getRandom(b3)
109+
return x + y + z
110+
}
111+
112+
// CHECK-IR: [[PREFER_FALSE]] = !{!"branch_weights", i32 1, i32 2000}

0 commit comments

Comments
 (0)