@@ -209,15 +209,16 @@ class DifferentiationTransformer {
209
209
210
210
ADContext &getContext () { return context; }
211
211
212
- // / Canonicalize the given witness, filling in JVP/VJPs if missing.
212
+ // / Canonicalize the given witness, filling in derivative functions if
213
+ // / missing.
213
214
// /
214
- // / \param explicitDifferentiable specifies whether the witness comes from an
215
- // / explicit `@differentiable` or `@derivative` attribute in the AST.
216
- // / If it does, we emit JVP/VJPs with the same linkage as the original
217
- // / so that they are linkable from other modules .
215
+ // / Generated derivative functions have the same linkage as the witness.
216
+ // /
217
+ // / \param serializeFunctions specifies whether generated functions should be
218
+ // / serialized .
218
219
bool canonicalizeDifferentiabilityWitness (
219
220
SILFunction *original, SILDifferentiabilityWitness *witness,
220
- DifferentiationInvoker invoker, bool explicitDifferentiable );
221
+ DifferentiationInvoker invoker, IsSerialized_t serializeFunctions );
221
222
222
223
// / Process the given `differentiable_function` instruction, filling in
223
224
// / missing derivative functions if necessary.
@@ -692,17 +693,31 @@ emitDerivativeFunctionReference(
692
693
originalFn->getLoweredFunctionType (), desiredParameterIndices,
693
694
contextualDerivativeGenSig);
694
695
minimalWitness = SILDifferentiabilityWitness::createDefinition (
695
- context.getModule (),
696
- originalFn->isSerialized () ? SILLinkage::Shared : SILLinkage::Hidden,
697
- originalFn, desiredParameterIndices, desiredResultIndices,
696
+ context.getModule (), SILLinkage::Private, originalFn,
697
+ desiredParameterIndices, desiredResultIndices,
698
698
derivativeConstrainedGenSig, /* jvp*/ nullptr ,
699
- /* vjp*/ nullptr , originalFn-> isSerialized () );
699
+ /* vjp*/ nullptr , /* isSerialized*/ false );
700
700
if (transformer.canonicalizeDifferentiabilityWitness (
701
- originalFn, minimalWitness, invoker,
702
- /* explicitDifferentiable*/ false ))
701
+ originalFn, minimalWitness, invoker, IsNotSerialized))
703
702
return None;
704
703
}
705
704
assert (minimalWitness);
705
+ if (original->getFunction ()->isSerialized () &&
706
+ !hasPublicVisibility (minimalWitness->getLinkage ())) {
707
+ enum { Inlinable = 0 , DefaultArgument = 1 };
708
+ unsigned fragileKind = Inlinable;
709
+ // FIXME: This is not a very robust way of determining if the function is
710
+ // a default argument. Also, we have not exhaustively listed all the kinds
711
+ // of fragility.
712
+ if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
713
+ fragileKind = DefaultArgument;
714
+ context.emitNondifferentiabilityError (
715
+ original, invoker, diag::autodiff_private_derivative_from_fragile,
716
+ fragileKind,
717
+ llvm::isa_and_nonnull<AbstractClosureExpr>(
718
+ originalFRI->getLoc ().getAsASTNode <Expr>()));
719
+ return None;
720
+ }
706
721
// TODO(TF-482): Move generic requirement checking logic to
707
722
// `getExactDifferentiabilityWitness` &
708
723
// `getOrCreateMinimalASTDifferentiabilityWitness`.
@@ -1503,12 +1518,11 @@ class VJPEmitter final
1503
1518
original->getASTContext ());
1504
1519
1505
1520
SILOptFunctionBuilder fb (context.getTransform ());
1506
- // The generated pullback linkage is set to Hidden because generated
1507
- // pullbacks are never called cross-module.
1508
- auto linkage = SILLinkage::Hidden;
1521
+ auto linkage =
1522
+ vjp->isSerialized () ? SILLinkage::Public : SILLinkage::Private;
1509
1523
auto *pullback = fb.createFunction (
1510
1524
linkage, pbName, pbType, pbGenericEnv, original->getLocation (),
1511
- original->isBare (), IsNotTransparent, original ->isSerialized (),
1525
+ original->isBare (), IsNotTransparent, vjp ->isSerialized (),
1512
1526
original->isDynamicallyReplaceable ());
1513
1527
pullback->setDebugScope (new (module )
1514
1528
SILDebugScope (original->getLocation (),
@@ -3114,18 +3128,20 @@ class JVPEmitter final
3114
3128
witness->getSILAutoDiffIndices (), jvp)),
3115
3129
differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
3116
3130
jvp, witness->getSILAutoDiffIndices (), activityInfo),
3117
- differentialBuilder(SILBuilder(* createEmptyDifferential (
3118
- context, original , witness, &differentialInfo))),
3131
+ differentialBuilder(SILBuilder(
3132
+ * createEmptyDifferential ( context, witness, &differentialInfo))),
3119
3133
diffLocalAllocBuilder(getDifferential()) {
3120
3134
// Create empty differential function.
3121
3135
context.recordGeneratedFunction (&getDifferential ());
3122
3136
}
3123
3137
3124
3138
static SILFunction *
3125
- createEmptyDifferential (ADContext &context, SILFunction *original,
3139
+ createEmptyDifferential (ADContext &context,
3126
3140
SILDifferentiabilityWitness *witness,
3127
3141
LinearMapInfo *linearMapInfo) {
3128
3142
auto &module = context.getModule ();
3143
+ auto *original = witness->getOriginalFunction ();
3144
+ auto *jvp = witness->getJVP ();
3129
3145
auto origTy = original->getLoweredFunctionType ();
3130
3146
auto lookupConformance = LookUpConformanceInModule (module .getSwiftModule ());
3131
3147
@@ -3186,12 +3202,11 @@ class JVPEmitter final
3186
3202
original->getASTContext ());
3187
3203
3188
3204
SILOptFunctionBuilder fb (context.getTransform ());
3189
- // The generated tangent linkage is set to Hidden because generated tangent
3190
- // are never called cross-module.
3191
- auto linkage = SILLinkage::Hidden;
3205
+ auto linkage =
3206
+ jvp->isSerialized () ? SILLinkage::Public : SILLinkage::Hidden;
3192
3207
auto *differential = fb.createFunction (
3193
3208
linkage, diffName, diffType, diffGenericEnv, original->getLocation (),
3194
- original->isBare (), IsNotTransparent, original ->isSerialized (),
3209
+ original->isBare (), IsNotTransparent, jvp ->isSerialized (),
3195
3210
original->isDynamicallyReplaceable ());
3196
3211
differential->setDebugScope (
3197
3212
new (module ) SILDebugScope (original->getLocation (), differential));
@@ -5783,7 +5798,7 @@ bool VJPEmitter::run() {
5783
5798
5784
5799
static SILFunction *createEmptyVJP (ADContext &context, SILFunction *original,
5785
5800
SILDifferentiabilityWitness *witness,
5786
- SILLinkage linkage ) {
5801
+ IsSerialized_t isSerialized ) {
5787
5802
LLVM_DEBUG ({
5788
5803
auto &s = getADDebugStream ();
5789
5804
s << " Creating VJP:\n\t " ;
@@ -5817,10 +5832,10 @@ static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,
5817
5832
vjpGenericSig);
5818
5833
5819
5834
SILOptFunctionBuilder fb (context.getTransform ());
5820
- auto *vjp = fb.createFunction (linkage, vjpName, vjpType, vjpGenericEnv,
5821
- original-> getLocation (), original-> isBare () ,
5822
- IsNotTransparent , original->isSerialized () ,
5823
- original->isDynamicallyReplaceable ());
5835
+ auto *vjp = fb.createFunction (
5836
+ witness-> getLinkage (), vjpName, vjpType, vjpGenericEnv ,
5837
+ original-> getLocation () , original->isBare (), IsNotTransparent ,
5838
+ isSerialized, original->isDynamicallyReplaceable ());
5824
5839
vjp->setDebugScope (new (module ) SILDebugScope (original->getLocation (), vjp));
5825
5840
5826
5841
LLVM_DEBUG (llvm::dbgs () << " VJP type: " << vjp->getLoweredFunctionType ()
@@ -5830,7 +5845,7 @@ static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,
5830
5845
5831
5846
static SILFunction *createEmptyJVP (ADContext &context, SILFunction *original,
5832
5847
SILDifferentiabilityWitness *witness,
5833
- SILLinkage linkage ) {
5848
+ IsSerialized_t isSerialized ) {
5834
5849
LLVM_DEBUG ({
5835
5850
auto &s = getADDebugStream ();
5836
5851
s << " Creating JVP:\n\t " ;
@@ -5864,10 +5879,10 @@ static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
5864
5879
LookUpConformanceInModule (module .getSwiftModule ()), jvpGenericSig);
5865
5880
5866
5881
SILOptFunctionBuilder fb (context.getTransform ());
5867
- auto *jvp = fb.createFunction (linkage, jvpName, jvpType, jvpGenericEnv,
5868
- original-> getLocation (), original-> isBare () ,
5869
- IsNotTransparent , original->isSerialized () ,
5870
- original->isDynamicallyReplaceable ());
5882
+ auto *jvp = fb.createFunction (
5883
+ witness-> getLinkage (), jvpName, jvpType, jvpGenericEnv ,
5884
+ original-> getLocation () , original->isBare (), IsNotTransparent ,
5885
+ isSerialized, original->isDynamicallyReplaceable ());
5871
5886
jvp->setDebugScope (new (module ) SILDebugScope (original->getLocation (), jvp));
5872
5887
5873
5888
LLVM_DEBUG (llvm::dbgs () << " JVP type: " << jvp->getLoweredFunctionType ()
@@ -5878,7 +5893,7 @@ static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
5878
5893
// / Returns true on error.
5879
5894
bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness (
5880
5895
SILFunction *original, SILDifferentiabilityWitness *witness,
5881
- DifferentiationInvoker invoker, bool explicitDifferentiable ) {
5896
+ DifferentiationInvoker invoker, IsSerialized_t serializeFunctions ) {
5882
5897
std::string traceMessage;
5883
5898
llvm::raw_string_ostream OS (traceMessage);
5884
5899
OS << " processing " ;
@@ -5889,9 +5904,6 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
5889
5904
5890
5905
assert (witness->isDefinition ());
5891
5906
5892
- auto derivativeFunctionLinkage =
5893
- explicitDifferentiable ? original->getLinkage () : SILLinkage::Hidden;
5894
-
5895
5907
// If the JVP doesn't exist, need to synthesize it.
5896
5908
if (!witness->getJVP ()) {
5897
5909
// Diagnose:
@@ -5903,7 +5915,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
5903
5915
return true ;
5904
5916
5905
5917
witness->setJVP (
5906
- createEmptyJVP (context, original, witness, derivativeFunctionLinkage ));
5918
+ createEmptyJVP (context, original, witness, serializeFunctions ));
5907
5919
context.recordGeneratedFunction (witness->getJVP ());
5908
5920
5909
5921
// For now, only do JVP generation if the flag is enabled and if custom VJP
@@ -5974,7 +5986,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
5974
5986
return true ;
5975
5987
5976
5988
witness->setVJP (
5977
- createEmptyVJP (context, original, witness, derivativeFunctionLinkage ));
5989
+ createEmptyVJP (context, original, witness, serializeFunctions ));
5978
5990
context.recordGeneratedFunction (witness->getVJP ());
5979
5991
VJPEmitter emitter (context, original, witness, witness->getVJP (), invoker);
5980
5992
return emitter.run ();
@@ -6731,7 +6743,7 @@ void Differentiation::run() {
6731
6743
auto invoker = invokerPair.second ;
6732
6744
6733
6745
if (transformer.canonicalizeDifferentiabilityWitness (
6734
- original, witness, invoker, /* explicitDifferentiable */ true ))
6746
+ original, witness, invoker, original-> isSerialized () ))
6735
6747
errorOccurred = true ;
6736
6748
}
6737
6749
0 commit comments