Skip to content

[AutoDiff] improve symbol linkage #28582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,12 @@ NOTE(autodiff_expression_not_differentiable_note,none,
NOTE(autodiff_external_nondifferentiable_function,none,
"cannot differentiate functions that have not been marked "
"'@differentiable' and that are defined in other files", ())
NOTE(autodiff_private_derivative_from_fragile,none,
"differentiated functions in "
"%select{'@inlinable' functions|default arguments}0 must be marked "
"'@differentiable' or have a public '@derivative'"
"%select{|; this is not possible with a closure, make a top-level "
"function instead}1", (unsigned, bool))
NOTE(autodiff_nondifferentiable_argument,none,
"cannot differentiate through a non-differentiable argument; do you want "
"to use 'withoutDerivative(at:)'?", ())
Expand Down
5 changes: 5 additions & 0 deletions lib/IRGen/GenDiffWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ void IRGenModule::emitSILDifferentiabilityWitness(
if (dw->isDeclaration())
return;

// Don't emit public_external witnesses.
if (hasPublicVisibility(dw->getLinkage()) &&
isAvailableExternally(dw->getLinkage()))
return;

ConstantInitBuilder builder(*this);
auto diffWitnessContents = builder.beginStruct();

Expand Down
3 changes: 2 additions & 1 deletion lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,8 @@ void SILGenModule::emitDifferentiabilityWitness(
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
M, originalFunction->getLinkage(), originalFunction, loweredParamIndices,
config.resultIndices, config.derivativeGenericSignature,
/*jvp*/ nullptr, /*vjp*/ nullptr, originalFunction->isSerialized(),
/*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably shouldn't always serialize witnesses for any public declarations, because things we emit should be opaque by default. Maybe you should add logic to the cross module optimizer that was recently added (#28407) to make differentiability witnesses serialized.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do this later because this PR is now very tested and almost ready to go. https://bugs.swift.org/browse/TF-1035

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Thanks.

diffAttr);

// Set derivative function in differentiability witness.
Expand Down
92 changes: 52 additions & 40 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,16 @@ class DifferentiationTransformer {

ADContext &getContext() { return context; }

/// Canonicalize the given witness, filling in JVP/VJPs if missing.
/// Canonicalize the given witness, filling in derivative functions if
/// missing.
///
/// \param explicitDifferentiable specifies whether the witness comes from an
/// explicit `@differentiable` or `@derivative` attribute in the AST.
/// If it does, we emit JVP/VJPs with the same linkage as the original
/// so that they are linkable from other modules.
/// Generated derivative functions have the same linkage as the witness.
///
/// \param serializeFunctions specifies whether generated functions should be
/// serialized.
bool canonicalizeDifferentiabilityWitness(
SILFunction *original, SILDifferentiabilityWitness *witness,
DifferentiationInvoker invoker, bool explicitDifferentiable);
DifferentiationInvoker invoker, IsSerialized_t serializeFunctions);

/// Process the given `differentiable_function` instruction, filling in
/// missing derivative functions if necessary.
Expand Down Expand Up @@ -691,17 +692,31 @@ emitDerivativeFunctionReference(
originalFn->getLoweredFunctionType(), desiredParameterIndices,
contextualDerivativeGenSig);
minimalWitness = SILDifferentiabilityWitness::createDefinition(
context.getModule(),
originalFn->isSerialized() ? SILLinkage::Shared : SILLinkage::Hidden,
originalFn, desiredParameterIndices, desiredResultIndices,
context.getModule(), SILLinkage::Private, originalFn,
desiredParameterIndices, desiredResultIndices,
derivativeConstrainedGenSig, /*jvp*/ nullptr,
/*vjp*/ nullptr, originalFn->isSerialized());
/*vjp*/ nullptr, /*isSerialized*/ false);
if (transformer.canonicalizeDifferentiabilityWitness(
originalFn, minimalWitness, invoker,
/*explicitDifferentiable*/ false))
originalFn, minimalWitness, invoker, IsNotSerialized))
return None;
}
assert(minimalWitness);
if (original->getFunction()->isSerialized() &&
!hasPublicVisibility(minimalWitness->getLinkage())) {
enum { Inlinable = 0, DefaultArgument = 1 };
unsigned fragileKind = Inlinable;
// FIXME: This is not a very robust way of determining if the function is
// a default argument. Also, we have not exhaustively listed all the kinds
// of fragility.
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
fragileKind = DefaultArgument;
context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_private_derivative_from_fragile,
fragileKind,
llvm::isa_and_nonnull<AbstractClosureExpr>(
originalFRI->getLoc().getAsASTNode<Expr>()));
return None;
}
// TODO(TF-482): Move generic requirement checking logic to
// `getExactDifferentiabilityWitness` &
// `getOrCreateMinimalASTDifferentiabilityWitness`.
Expand Down Expand Up @@ -1502,12 +1517,11 @@ class VJPEmitter final
original->getASTContext());

SILOptFunctionBuilder fb(context.getTransform());
// The generated pullback linkage is set to Hidden because generated
// pullbacks are never called cross-module.
auto linkage = SILLinkage::Hidden;
auto linkage =
vjp->isSerialized() ? SILLinkage::Public : SILLinkage::Private;
auto *pullback = fb.createFunction(
linkage, pbName, pbType, pbGenericEnv, original->getLocation(),
original->isBare(), IsNotTransparent, original->isSerialized(),
original->isBare(), IsNotTransparent, vjp->isSerialized(),
original->isDynamicallyReplaceable());
pullback->setDebugScope(new (module)
SILDebugScope(original->getLocation(),
Expand Down Expand Up @@ -3275,18 +3289,20 @@ class JVPEmitter final
witness->getSILAutoDiffIndices(), jvp)),
differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
jvp, witness->getSILAutoDiffIndices(), activityInfo),
differentialBuilder(SILBuilder(*createEmptyDifferential(
context, original, witness, &differentialInfo))),
differentialBuilder(SILBuilder(
*createEmptyDifferential(context, witness, &differentialInfo))),
diffLocalAllocBuilder(getDifferential()) {
// Create empty differential function.
context.recordGeneratedFunction(&getDifferential());
}

static SILFunction *
createEmptyDifferential(ADContext &context, SILFunction *original,
createEmptyDifferential(ADContext &context,
SILDifferentiabilityWitness *witness,
LinearMapInfo *linearMapInfo) {
auto &module = context.getModule();
auto *original = witness->getOriginalFunction();
auto *jvp = witness->getJVP();
auto origTy = original->getLoweredFunctionType();
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());

Expand Down Expand Up @@ -3347,12 +3363,11 @@ class JVPEmitter final
original->getASTContext());

SILOptFunctionBuilder fb(context.getTransform());
// The generated tangent linkage is set to Hidden because generated tangent
// are never called cross-module.
auto linkage = SILLinkage::Hidden;
auto linkage =
jvp->isSerialized() ? SILLinkage::Public : SILLinkage::Hidden;
auto *differential = fb.createFunction(
linkage, diffName, diffType, diffGenericEnv, original->getLocation(),
original->isBare(), IsNotTransparent, original->isSerialized(),
original->isBare(), IsNotTransparent, jvp->isSerialized(),
original->isDynamicallyReplaceable());
differential->setDebugScope(
new (module) SILDebugScope(original->getLocation(), differential));
Expand Down Expand Up @@ -5938,7 +5953,7 @@ bool VJPEmitter::run() {

static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,
SILDifferentiabilityWitness *witness,
SILLinkage linkage) {
IsSerialized_t isSerialized) {
LLVM_DEBUG({
auto &s = getADDebugStream();
s << "Creating VJP:\n\t";
Expand Down Expand Up @@ -5972,10 +5987,10 @@ static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,
vjpGenericSig);

SILOptFunctionBuilder fb(context.getTransform());
auto *vjp = fb.createFunction(linkage, vjpName, vjpType, vjpGenericEnv,
original->getLocation(), original->isBare(),
IsNotTransparent, original->isSerialized(),
original->isDynamicallyReplaceable());
auto *vjp = fb.createFunction(
witness->getLinkage(), vjpName, vjpType, vjpGenericEnv,
original->getLocation(), original->isBare(), IsNotTransparent,
isSerialized, original->isDynamicallyReplaceable());
vjp->setDebugScope(new (module) SILDebugScope(original->getLocation(), vjp));

LLVM_DEBUG(llvm::dbgs() << "VJP type: " << vjp->getLoweredFunctionType()
Expand All @@ -5985,7 +6000,7 @@ static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,

static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
SILDifferentiabilityWitness *witness,
SILLinkage linkage) {
IsSerialized_t isSerialized) {
LLVM_DEBUG({
auto &s = getADDebugStream();
s << "Creating JVP:\n\t";
Expand Down Expand Up @@ -6019,10 +6034,10 @@ static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
LookUpConformanceInModule(module.getSwiftModule()), jvpGenericSig);

SILOptFunctionBuilder fb(context.getTransform());
auto *jvp = fb.createFunction(linkage, jvpName, jvpType, jvpGenericEnv,
original->getLocation(), original->isBare(),
IsNotTransparent, original->isSerialized(),
original->isDynamicallyReplaceable());
auto *jvp = fb.createFunction(
witness->getLinkage(), jvpName, jvpType, jvpGenericEnv,
original->getLocation(), original->isBare(), IsNotTransparent,
isSerialized, original->isDynamicallyReplaceable());
jvp->setDebugScope(new (module) SILDebugScope(original->getLocation(), jvp));

LLVM_DEBUG(llvm::dbgs() << "JVP type: " << jvp->getLoweredFunctionType()
Expand All @@ -6033,7 +6048,7 @@ static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
/// Returns true on error.
bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
SILFunction *original, SILDifferentiabilityWitness *witness,
DifferentiationInvoker invoker, bool explicitDifferentiable) {
DifferentiationInvoker invoker, IsSerialized_t serializeFunctions) {
std::string traceMessage;
llvm::raw_string_ostream OS(traceMessage);
OS << "processing ";
Expand All @@ -6044,9 +6059,6 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(

assert(witness->isDefinition());

auto derivativeFunctionLinkage =
explicitDifferentiable ? original->getLinkage() : SILLinkage::Hidden;

// If the JVP doesn't exist, need to synthesize it.
if (!witness->getJVP()) {
// Diagnose:
Expand All @@ -6058,7 +6070,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
return true;

witness->setJVP(
createEmptyJVP(context, original, witness, derivativeFunctionLinkage));
createEmptyJVP(context, original, witness, serializeFunctions));
context.recordGeneratedFunction(witness->getJVP());

// For now, only do JVP generation if the flag is enabled and if custom VJP
Expand Down Expand Up @@ -6129,7 +6141,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
return true;

witness->setVJP(
createEmptyVJP(context, original, witness, derivativeFunctionLinkage));
createEmptyVJP(context, original, witness, serializeFunctions));
context.recordGeneratedFunction(witness->getVJP());
VJPEmitter emitter(context, original, witness, witness->getVJP(), invoker);
return emitter.run();
Expand Down Expand Up @@ -6886,7 +6898,7 @@ void Differentiation::run() {
auto invoker = invokerPair.second;

if (transformer.canonicalizeDifferentiabilityWitness(
original, witness, invoker, /*explicitDifferentiable*/ true))
original, witness, invoker, original->isSerialized()))
errorOccurred = true;
}

Expand Down
11 changes: 8 additions & 3 deletions lib/Serialization/DeserializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3534,8 +3534,9 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
assert(!isSerialized && "declaration must not be serialized");
}

auto linkage = fromStableSILLinkage(rawLinkage);
assert(linkage && "Expected value linkage for sil_differentiability_witness");
auto linkageOpt = fromStableSILLinkage(rawLinkage);
assert(linkageOpt &&
"Expected value linkage for sil_differentiability_witness");
auto originalName = MF->getIdentifierText(originalNameId);
auto jvpName = MF->getIdentifierText(jvpNameId);
auto vjpName = MF->getIdentifierText(vjpNameId);
Expand Down Expand Up @@ -3572,10 +3573,14 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
auto *diffWitness =
SILMod.lookUpDifferentiabilityWitness({originalName, config});

// Witnesses that we deserialize are always available externally; we never
// want to emit them ourselves.
auto linkage = swift::addExternalToLinkage(*linkageOpt);

// If there is no existing differentiability witness, create one.
if (!diffWitness)
diffWitness = SILDifferentiabilityWitness::createDeclaration(
SILMod, *linkage, original, parameterIndices, resultIndices,
SILMod, linkage, original, parameterIndices, resultIndices,
derivativeGenSig);

// If the current differentiability witness is merely a declaration, and the
Expand Down
27 changes: 27 additions & 0 deletions lib/TBDGen/TBDGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,30 @@ void TBDGenVisitor::addConformances(DeclContext *DC) {
}

// SWIFT_ENABLE_TENSORFLOW
void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
const DifferentiableAttr *attr,
AutoDiffLinearMapKind kind) {
auto declRef = SILDeclRef(original);

// Linear maps are only public when the original function is serialized.
if (!declRef.isSerialized())
return;

// Differentials are only emitted when forward mode is turned on.
if (kind == AutoDiffLinearMapKind::Differential &&
!original->getASTContext()
.LangOpts.EnableExperimentalForwardModeDifferentiation)
return;

auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
attr->getParameterIndices(),
original->getInterfaceType()->castTo<AnyFunctionType>());
Mangle::ASTMangler mangler;
std::string linearMapName = mangler.mangleAutoDiffLinearMapHelper(
declRef.mangle(), kind, SILAutoDiffIndices(0, loweredParamIndices));
addSymbol(linearMapName);
}

void TBDGenVisitor::addAutoDiffDerivativeFunction(
AbstractFunctionDecl *original, const DifferentiableAttr *attr,
AutoDiffDerivativeFunctionKind kind) {
Expand Down Expand Up @@ -208,6 +232,9 @@ void TBDGenVisitor::addDifferentiabilityWitness(

void TBDGenVisitor::addDifferentiableAttr(AbstractFunctionDecl *original,
const DifferentiableAttr *attr) {
addAutoDiffLinearMapFunction(original, attr,
AutoDiffLinearMapKind::Differential);
addAutoDiffLinearMapFunction(original, attr, AutoDiffLinearMapKind::Pullback);
addAutoDiffDerivativeFunction(original, attr,
AutoDiffDerivativeFunctionKind::JVP);
addAutoDiffDerivativeFunction(original, attr,
Expand Down
6 changes: 6 additions & 0 deletions lib/TBDGen/TBDGenVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ class TBDGenVisitor : public ASTVisitor<TBDGenVisitor> {
void addBaseConformanceDescriptor(BaseConformance conformance);

// SWIFT_ENABLE_TENSORFLOW
/// Adds the symbol for the linear map function of the given kind associated
/// with the given original function and `@differentiable` attr.
void addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
const DifferentiableAttr *attr,
AutoDiffLinearMapKind kind);

/// Adds the symbol for the autodiff function of the given kind associated
/// with the given original function and `@differentiable` attr.
void addAutoDiffDerivativeFunction(AbstractFunctionDecl *original,
Expand Down
14 changes: 14 additions & 0 deletions test/AutoDiff/Inputs/e2e_cross_module_external_module.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import DifferentiationUnittest

@differentiable
public func doubleThenApplyDefaultF(_ x: Tracked<Float>) -> Tracked<Float> {
return x
}

@differentiable
public func doubleThenApply(
_ x: Tracked<Float>,
_ f: @differentiable (Tracked<Float>) -> Tracked<Float> = doubleThenApplyDefaultF
) -> Tracked<Float> {
return f(2 * x)
}
4 changes: 2 additions & 2 deletions test/AutoDiff/control_flow_sil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func cond(_ x: Float) -> Float {
// CHECK-SIL: return [[VJP_RESULT]]


// CHECK-SIL-LABEL: sil hidden [ossa] @AD__cond__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__cond_bb3__PB__src_0_wrt_0) -> Float {
// CHECK-SIL-LABEL: sil private [ossa] @AD__cond__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__cond_bb3__PB__src_0_wrt_0) -> Float {
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : @owned $_AD__cond_bb3__PB__src_0_wrt_0):
// CHECK-SIL: [[BB3_PRED:%.*]] = destructure_struct [[BB3_PB_STRUCT]] : $_AD__cond_bb3__PB__src_0_wrt_0
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1
Expand Down Expand Up @@ -217,7 +217,7 @@ func cond_tuple_var(_ x: Float) -> Float {
return y.1
}

// CHECK-SIL-LABEL: sil hidden [ossa] @AD__cond_tuple_var__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__cond_tuple_var_bb3__PB__src_0_wrt_0) -> Float {
// CHECK-SIL-LABEL: sil private [ossa] @AD__cond_tuple_var__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__cond_tuple_var_bb3__PB__src_0_wrt_0) -> Float {
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0):
// CHECK-SIL: [[BB3_PRED:%.*]] = destructure_struct [[BB3_PB_STRUCT]] : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
Expand Down
32 changes: 32 additions & 0 deletions test/AutoDiff/differentiation_transform_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,35 @@ _ = gradient(at: Float(1), Float(2), in: (+) as @differentiable (Float, @nondiff

// expected-error @+1 {{conversion to '@differentiable(linear)' function type is not yet supported}}
let _: @differentiable(linear) (Float) -> Float = { x in x }

//===----------------------------------------------------------------------===//
// Differentiating from fragile functions
//===----------------------------------------------------------------------===//

public func implicitlyDifferentiableFromFragile(_ x: Float) -> Float { x }

public func hasImplicitlyDifferentiatedTopLevelDefaultArgument(
// expected-error @+2 {{function is not differentiable}}
// expected-note @+1 {{differentiated functions in default arguments must be marked '@differentiable' or have a public '@derivative'}}
_ f: @differentiable (Float) -> Float = implicitlyDifferentiableFromFragile
) {}

// TODO(TF-1030): This will eventually not be an error.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+1 {{differentiated functions in default arguments must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead}}
public func hasImplicitlyDifferentiatedClosureDefaultArgument(_ f: @differentiable (Float) -> Float = { $0 }) {}

@inlinable
public func fragileFuncWithGradient() {
// expected-error @+2 {{function is not differentiable}}
// expected-note @+1 {{differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'}}
let _ = gradient(at: 0, in: implicitlyDifferentiableFromFragile)
}

@inlinable
@differentiable
public func fragileDifferentiable(_ x: Float) -> Float {
// expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'}}
implicitlyDifferentiableFromFragile(x)
}
Loading