Skip to content

Commit 0848df8

Browse files
authored
[AutoDiff] Mark SILGen derivative function thunks as "always inline". (#27754)
SILGen derivative function thunks must be publicly exposed and cannot be `[transparent]`. Instead, mark them as "always inline" for optimization.
1 parent 519236c commit 0848df8

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3696,12 +3696,13 @@ SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
36963696
auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage(
36973697
original->getLinkage(), /*isDerivativeFnExported*/ true);
36983698
// This thunk is publicly exposed and cannot be transparent.
3699-
// TODO(TF-925): Mark the thunks as "always inline" for optimization.
3699+
// Instead, mark it as "always inline" for optimization.
37003700
auto *thunk = fb.getOrCreateFunction(
37013701
loc, name, linkage, origDerivativeFnType, IsBare, IsNotTransparent,
37023702
derivativeFn->isSerialized(), derivativeFn->isDynamicallyReplaceable(),
37033703
derivativeFn->getEntryCount(), derivativeFn->isThunk(),
37043704
derivativeFn->getClassSubclassScope());
3705+
thunk->setInlineStrategy(AlwaysInline);
37053706
if (!thunk->empty())
37063707
return thunk;
37073708
thunk->setGenericEnvironment(thunkGenericEnv);

lib/SILGen/SILGenThunk.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,12 @@ SILGenModule::getOrCreateAutoDiffDerivativeForwardingThunk(
8787
originalLinkage, /*isDerivativeFnExported*/ true);
8888
auto name = derivativeFnDeclRef.mangle();
8989
// This thunk is publicly exposed and cannot be transparent.
90-
// TODO(TF-925): Mark the thunks as "always inline" for optimization.
90+
// Instead, mark it as "always inline" for optimization.
9191
auto *thunk = builder.getOrCreateFunction(
9292
derivativeFnDecl, name, linkage, derivativeFnTy, IsBare, IsNotTransparent,
9393
derivativeFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(),
9494
IsThunk);
95+
thunk->setInlineStrategy(AlwaysInline);
9596
if (!thunk->empty())
9697
return thunk;
9798

test/AutoDiff/silgen_thunking/main.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func vjpNoReabstraction<T: Differentiable>(_ x: T) -> (T, (T.TangentVector) -> T
1919
return (x, { $0 })
2020
}
2121
// Find the non-`[transparent]` SILGen thunk.
22-
// CHECK-LABEL: sil hidden [thunk] [ossa] @AD__$s4main15noReabstractionyxxs15_DifferentiableRzlF__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector)
22+
// CHECK-LABEL: sil hidden [thunk] [always_inline] [ossa] @AD__$s4main15noReabstractionyxxs15_DifferentiableRzlF__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector)
2323

2424
var DerivativeSILGenThunkTests = TestSuite("DerivativeSILGenThunks")
2525

@@ -51,7 +51,7 @@ struct SelfReordering : Differentiable & AdditiveArithmetic {
5151
return (value, { v in (Self(1), Self(2), Self(3)) })
5252
}
5353

54-
// CHECK-LABEL: sil hidden [ossa] @AD__$s4main14SelfReorderingV20threeParameterMethodyA2C_ACtF__jvp_src_0_wrt_0_1_2 : $@convention(method) (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> (@owned SelfReordering, @owned @callee_guaranteed (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> @owned SelfReordering)
54+
// CHECK-LABEL: sil hidden [always_inline] [ossa] @AD__$s4main14SelfReorderingV20threeParameterMethodyA2C_ACtF__jvp_src_0_wrt_0_1_2 : $@convention(method) (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> (@owned SelfReordering, @owned @callee_guaranteed (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> @owned SelfReordering)
5555
// CHECK: bb0([[X:%.*]] : @guaranteed $SelfReordering, [[Y:%.*]] : @guaranteed $SelfReordering, [[SELF:%.*]] : @guaranteed $SelfReordering):
5656
// CHECK: [[JVP:%.*]] = function_ref @$s4main14SelfReorderingV23jvpThreeParameterMethodyAC_A2C_A2CtctAC_ACtF
5757
// CHECK: [[JVP_RESULT:%.*]] = apply [[JVP]]([[X]], [[Y]], [[SELF]])
@@ -66,7 +66,7 @@ struct SelfReordering : Differentiable & AdditiveArithmetic {
6666
// CHECK: [[DF_RESULT:%.*]] = apply [[DF]]([[DSELF]], [[DX]], [[DY]])
6767
// CHECK: return [[DF_RESULT]]
6868

69-
// CHECK-LABEL: sil hidden [ossa] @AD__$s4main14SelfReorderingV20threeParameterMethodyA2C_ACtF__vjp_src_0_wrt_0_1_2 : $@convention(method) (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> (@owned SelfReordering, @owned @callee_guaranteed (@guaranteed SelfReordering) -> (@owned SelfReordering, @owned SelfReordering, @owned SelfReordering))
69+
// CHECK-LABEL: sil hidden [always_inline] [ossa] @AD__$s4main14SelfReorderingV20threeParameterMethodyA2C_ACtF__vjp_src_0_wrt_0_1_2 : $@convention(method) (@guaranteed SelfReordering, @guaranteed SelfReordering, @guaranteed SelfReordering) -> (@owned SelfReordering, @owned @callee_guaranteed (@guaranteed SelfReordering) -> (@owned SelfReordering, @owned SelfReordering, @owned SelfReordering))
7070
// CHECK: bb0([[X:%.*]] : @guaranteed $SelfReordering, [[Y:%.*]] : @guaranteed $SelfReordering, [[SELF:%.*]] : @guaranteed $SelfReordering):
7171
// CHECK: [[VJP:%.*]] = function_ref @$s4main14SelfReorderingV23vjpThreeParameterMethodyAC_AC_A2CtACctAC_ACtF
7272
// CHECK: [[VJP_RESULT:%.*]] = apply [[VJP]]([[X]], [[Y]], [[SELF]])
@@ -115,7 +115,7 @@ where Dummy: Differentiable & ExpressibleByIntegerLiteral {
115115
return (value, { v in (v, 2.0, 3.0) })
116116
}
117117

118-
// CHECK-LABEL: sil hidden [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts15_DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__jvp_src_0_wrt_0_1_2 : $@convention(method) <τ_0_0 where τ_0_0 : ExpressibleByIntegerLiteral, τ_0_0 : _Differentiable><τ_1_0, τ_1_1 where τ_1_0 : _Differentiable, τ_1_1 : _Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed (@in_guaranteed τ_1_0.TangentVector, @in_guaranteed τ_1_1.TangentVector, @in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> @out SelfReorderingGeneric<τ_0_0>.TangentVector) {
118+
// CHECK-LABEL: sil hidden [always_inline] [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts15_DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__jvp_src_0_wrt_0_1_2 : $@convention(method) <τ_0_0 where τ_0_0 : ExpressibleByIntegerLiteral, τ_0_0 : _Differentiable><τ_1_0, τ_1_1 where τ_1_0 : _Differentiable, τ_1_1 : _Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed (@in_guaranteed τ_1_0.TangentVector, @in_guaranteed τ_1_1.TangentVector, @in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> @out SelfReorderingGeneric<τ_0_0>.TangentVector) {
119119
// CHECK: bb0([[JVP_RESULT:%.*]] : $*SelfReorderingGeneric<τ_0_0>, [[X:%.*]] : $*τ_1_0, [[Y:%.*]] : $*τ_1_1, [[SELF:%.*]] : $*SelfReorderingGeneric<τ_0_0>):
120120
// CHECK: [[JVP:%.*]] = function_ref @$s4main21SelfReorderingGenericV23jvpThreeParameterMethodyACyxG_AC13TangentVectorVyx_GAH_AFQyd__AFQyd_0_tctqd___qd_0_ts15_DifferentiableRd__sAKRd_0_s25ExpressibleByFloatLiteralAIRQsAlJRQr0_lF
121121
// CHECK: [[DF:%.*]] = apply [[JVP]]<τ_0_0, τ_1_0, τ_1_1>([[JVP_RESULT]], [[X]], [[Y]], [[SELF]])
@@ -129,7 +129,7 @@ where Dummy: Differentiable & ExpressibleByIntegerLiteral {
129129
// CHECK: [[VOID:%.*]] = tuple ()
130130
// CHECK: return [[VOID]]
131131

132-
// CHECK-LABEL: sil hidden [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts15_DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__vjp_src_0_wrt_0_1_2 : $@convention(method) <τ_0_0 where τ_0_0 : ExpressibleByIntegerLiteral, τ_0_0 : _Differentiable><τ_1_0, τ_1_1 where τ_1_0 : _Differentiable, τ_1_1 : _Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> (@out τ_1_0.TangentVector, @out τ_1_1.TangentVector, @out SelfReorderingGeneric<τ_0_0>.TangentVector)) {
132+
// CHECK-LABEL: sil hidden [always_inline] [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts15_DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__vjp_src_0_wrt_0_1_2 : $@convention(method) <τ_0_0 where τ_0_0 : ExpressibleByIntegerLiteral, τ_0_0 : _Differentiable><τ_1_0, τ_1_1 where τ_1_0 : _Differentiable, τ_1_1 : _Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> (@out τ_1_0.TangentVector, @out τ_1_1.TangentVector, @out SelfReorderingGeneric<τ_0_0>.TangentVector)) {
133133
// CHECK: bb0([[VJP_RESULT:%.*]] : $*SelfReorderingGeneric<τ_0_0>, [[X:%.*]] : $*τ_1_0, [[Y:%.*]] : $*τ_1_1, [[SELF:%.*]] : $*SelfReorderingGeneric<τ_0_0>):
134134
// CHECK: [[VJP:%.*]] = function_ref @$s4main21SelfReorderingGenericV23vjpThreeParameterMethodyACyxG_AC13TangentVectorVyx_G_AFQyd__AFQyd_0_tAHctqd___qd_0_ts15_DifferentiableRd__sAKRd_0_s25ExpressibleByFloatLiteralAIRQsAlJRQr0_lF
135135
// CHECK: [[PB:%.*]] = apply [[VJP]]<τ_0_0, τ_1_0, τ_1_1>([[VJP_RESULT]], [[X]], [[Y]], [[SELF]])

0 commit comments

Comments
 (0)