Skip to content

Commit 7c384f8

Browse files
authored
[AutoDiff] SILGen derivative function thunks must not be transparent. (#27752)
SILGen generates thunks for derivative functions registered via `@differentiable` and `@differentiating` attributes. The thunks have a canonical, consistent naming (currently based on the original function name and parameter indices) that TBDGen can also generate via `SILDeclRef::mangle`. These SILGen derivative function thunks must not be transparent; otherwise, the functions will be inlined during MandatoryInlining and will not be exposed publicly. This bug was discovered as a DeadFunctionElimination crash: `function_ref` instructions referencing transparent derivative thunks had their referenced functions set to null by MandatoryInling and became invalid.
1 parent fe96839 commit 7c384f8

File tree

4 files changed

+23
-1
lines changed

4 files changed

+23
-1
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3695,6 +3695,8 @@ SILGenModule::getOrCreateAutoDiffDerivativeFunctionThunk(
36953695
SILGenFunctionBuilder fb(*this);
36963696
auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage(
36973697
original->getLinkage(), /*isDerivativeFnExported*/ true);
3698+
// This thunk is publicly exposed and cannot be transparent.
3699+
// TODO(TF-925): Mark the thunks as "always inline" for optimization.
36983700
auto *thunk = fb.getOrCreateFunction(
36993701
loc, name, linkage, origDerivativeFnType, IsBare, IsNotTransparent,
37003702
derivativeFn->isSerialized(), derivativeFn->isDynamicallyReplaceable(),

lib/SILGen/SILGenThunk.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,10 @@ SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef derivativeFnDeclRef,
8686
auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage(
8787
originalLinkage, /*isDerivativeFnExported*/ true);
8888
auto name = derivativeFnDeclRef.mangle();
89+
// This thunk is publicly exposed and cannot be transparent.
90+
// TODO(TF-925): Mark the thunks as "always inline" for optimization.
8991
auto *thunk = builder.getOrCreateFunction(
90-
derivativeFnDecl, name, linkage, derivativeFnTy, IsBare, IsTransparent,
92+
derivativeFnDecl, name, linkage, derivativeFnTy, IsBare, IsNotTransparent,
9193
derivativeFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(),
9294
IsThunk);
9395
if (!thunk->empty())

test/AutoDiff/silgen_thunking/main.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@
1010
import StdlibUnittest
1111
import DifferentiationUnittest
1212

13+
// Verify that SILGen derivative thunks are never `[transparent]`.
14+
@differentiable(vjp: vjpNoReabstraction)
15+
func noReabstraction<T: Differentiable>(_ x: T) -> T {
16+
return x
17+
}
18+
func vjpNoReabstraction<T: Differentiable>(_ x: T) -> (T, (T.TangentVector) -> T.TangentVector) {
19+
return (x, { $0 })
20+
}
21+
// Find the non-`[transparent]` SILGen thunk.
22+
// CHECK-LABEL: sil hidden [thunk] [ossa] @AD__$s4main15noReabstractionyxxs15_DifferentiableRzlF__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector)
23+
1324
var DerivativeSILGenThunkTests = TestSuite("DerivativeSILGenThunks")
1425

1526
// TF-619: Test cross-module import of `@differentiable` methods with

test/AutoDiff/tbdgen.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
@differentiable public func publicDiffable(_ x: Float, _ y: Float) -> Float { return x }
1313
@differentiable(wrt: (x)) public func publicDiffableWRT(_ x: Float, _ y: Float) -> Float { return x }
1414

15+
// Tests SILGen derivative "forwarding thunk" (no derivative reabstraction/self-reordering).
16+
@differentiable(vjp: publicNoDerivativeReabstractionVJP)
17+
public func publicNoDerivativeReabstraction<T: Differentiable>(_ x: T) -> T { return x }
18+
public func publicNoDerivativeReabstractionVJP<T: Differentiable>(_ x: T) -> (T, (T.TangentVector) -> T.TangentVector) {
19+
return (x, { $0 })
20+
}
21+
1522
@differentiable internal func internalDiffable(_ x: Float, _ y: Float) -> Float { return x }
1623
@differentiable(wrt: (x)) internal func internalDiffableWRT(_ x: Float, _ y: Float) -> Float { return x }
1724

0 commit comments

Comments
 (0)