Skip to content

[AutoDiff] SILGen derivative function thunks must not be transparent. #27752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3695,6 +3695,8 @@ SILGenModule::getOrCreateAutoDiffDerivativeFunctionThunk(
SILGenFunctionBuilder fb(*this);
auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage(
original->getLinkage(), /*isDerivativeFnExported*/ true);
// This thunk is publicly exposed and cannot be transparent.
// TODO(TF-925): Mark the thunks as "always inline" for optimization.
auto *thunk = fb.getOrCreateFunction(
loc, name, linkage, origDerivativeFnType, IsBare, IsNotTransparent,
derivativeFn->isSerialized(), derivativeFn->isDynamicallyReplaceable(),
Expand Down
4 changes: 3 additions & 1 deletion lib/SILGen/SILGenThunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef derivativeFnDeclRef,
auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage(
originalLinkage, /*isDerivativeFnExported*/ true);
auto name = derivativeFnDeclRef.mangle();
// This thunk is publicly exposed and cannot be transparent.
// TODO(TF-925): Mark the thunks as "always inline" for optimization.
auto *thunk = builder.getOrCreateFunction(
derivativeFnDecl, name, linkage, derivativeFnTy, IsBare, IsTransparent,
derivativeFnDecl, name, linkage, derivativeFnTy, IsBare, IsNotTransparent,
derivativeFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(),
IsThunk);
if (!thunk->empty())
Expand Down
11 changes: 11 additions & 0 deletions test/AutoDiff/silgen_thunking/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@
import StdlibUnittest
import DifferentiationUnittest

// Verify that SILGen derivative thunks are never `[transparent]`.
@differentiable(vjp: vjpNoReabstraction)
func noReabstraction<T: Differentiable>(_ x: T) -> T {
return x
}
func vjpNoReabstraction<T: Differentiable>(_ x: T) -> (T, (T.TangentVector) -> T.TangentVector) {
return (x, { $0 })
}
// Find the non-`[transparent]` SILGen thunk.
// CHECK-LABEL: sil hidden [thunk] [ossa] @AD__$s4main15noReabstractionyxxs15_DifferentiableRzlF__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector)

var DerivativeSILGenThunkTests = TestSuite("DerivativeSILGenThunks")

// TF-619: Test cross-module import of `@differentiable` methods with
Expand Down
7 changes: 7 additions & 0 deletions test/AutoDiff/tbdgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
@differentiable public func publicDiffable(_ x: Float, _ y: Float) -> Float { return x }
@differentiable(wrt: (x)) public func publicDiffableWRT(_ x: Float, _ y: Float) -> Float { return x }

// Tests SILGen derivative "forwarding thunk" (no derivative reabstraction/self-reordering).
@differentiable(vjp: publicNoDerivativeReabstractionVJP)
public func publicNoDerivativeReabstraction<T: Differentiable>(_ x: T) -> T { return x }
public func publicNoDerivativeReabstractionVJP<T: Differentiable>(_ x: T) -> (T, (T.TangentVector) -> T.TangentVector) {
return (x, { $0 })
}

@differentiable internal func internalDiffable(_ x: Float, _ y: Float) -> Float { return x }
@differentiable(wrt: (x)) internal func internalDiffableWRT(_ x: Float, _ y: Float) -> Float { return x }

Expand Down