@@ -95,6 +95,28 @@ static SILFunction *lookUpOrLinkFunction(StringRef name, SILModule &module) {
95
95
return module .findFunction (name, SILLinkage::PublicExternal);
96
96
}
97
97
98
+ // / Computes the correct linkage for functions generated by the AD pass
99
+ // / associated with a function with linkage `originalLinkage`.
100
+ static SILLinkage getAutoDiffFunctionLinkage (SILLinkage originalLinkage) {
101
+ // If the original is defined externally, then the AD pass is just generating
102
+ // associated functions for use in the current module and therefore these
103
+ // associated functions should not be visible outside the module.
104
+ if (isAvailableExternally (originalLinkage))
105
+ return SILLinkage::Hidden;
106
+
107
+ // If the original is public, then external modules may need to link the
108
+ // associated function. Make the associated function public.
109
+ if (originalLinkage == SILLinkage::Public)
110
+ return SILLinkage::Public;
111
+ if (originalLinkage == SILLinkage::PublicNonABI)
112
+ return SILLinkage::PublicNonABI;
113
+
114
+ // Otherwise, the original function is defined and used only in the current
115
+ // module, so external modules will never try to access the associated
116
+ // function. Make the associated function hidden.
117
+ return SILLinkage::Hidden;
118
+ }
119
+
98
120
// / Given a function, gather all of its formal results (both direct and
99
121
// / indirect) in an order defined by its result type. Note that "formal results"
100
122
// / refer to result values in the body of the function, not at call sites.
@@ -2104,8 +2126,11 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2104
2126
// Clone.
2105
2127
cloneFunctionBody (original, entry, entryArgs);
2106
2128
// If errors occurred, back out.
2107
- if (errorOccurred)
2129
+ if (errorOccurred) {
2130
+ // Delete the body so that later passes don't get confused by invalid SIL.
2131
+ getPrimal ()->getBlocks ().clear ();
2108
2132
return true ;
2133
+ }
2109
2134
auto *origExit = &*original->findReturnBB ();
2110
2135
auto *exit = BBMap.lookup (origExit);
2111
2136
assert (exit->getParent () == getPrimal ());
@@ -4131,8 +4156,10 @@ void DifferentiationTask::createEmptyPrimal() {
4131
4156
auto indices = getIndices ();
4132
4157
auto *original = getOriginal ();
4133
4158
auto &module = context.getModule ();
4134
- std::string primalName =
4135
- " AD__" + original->getName ().str () + " __primal_" + indices.mangle ();
4159
+ auto primalName = original->getASTContext ()
4160
+ .getIdentifier (" AD__" + original->getName ().str () +
4161
+ " __primal_" + indices.mangle ())
4162
+ .str ();
4136
4163
StructDecl *primalValueStructDecl = context.createPrimalValueStruct (this );
4137
4164
primalInfo = std::unique_ptr<PrimalInfo>(new PrimalInfo (primalValueStructDecl, module ));
4138
4165
auto pvType = primalValueStructDecl->getDeclaredType ()->getCanonicalType ();
@@ -4151,17 +4178,21 @@ void DifferentiationTask::createEmptyPrimal() {
4151
4178
originalTy->getParameters (), originalTy->getYields (), results,
4152
4179
originalTy->getOptionalErrorResult (), context.getASTContext ());
4153
4180
SILOptFunctionBuilder fb (context.getTransform ());
4154
- auto linkage = original->getLinkage ();
4155
- if (linkage == SILLinkage::Public)
4156
- linkage = SILLinkage::PublicNonABI;
4181
+ // We set generated primal linkage to Hidden because generated primals are
4182
+ // never called cross-module in VJP mode: all cross-module calls to associated
4183
+ // functions call the VJP.
4184
+ // TODO: In order for cross-module calls to work in non-VJP mode, we must use
4185
+ // `getAutoDiffFunctionLinkage` to make the linkage occasionally public. We'll
4186
+ // also need to update TBDGen to generate TBD entries for public primals.
4187
+ auto linkage = SILLinkage::Hidden;
4157
4188
primal = fb.getOrCreateFunction (
4158
4189
original->getLocation (), primalName, linkage, primalTy,
4159
4190
original->isBare (), original->isTransparent (), original->isSerialized ());
4160
4191
primal->setUnqualifiedOwnership ();
4161
4192
LLVM_DEBUG (getADDebugStream () << " Primal function created \n "
4162
4193
<< *primal << ' \n ' );
4163
4194
4164
- attr->setPrimalName (primal-> getName () );
4195
+ attr->setPrimalName (primalName );
4165
4196
}
4166
4197
4167
4198
void DifferentiationTask::createEmptyAdjoint () {
@@ -4263,24 +4294,31 @@ void DifferentiationTask::createEmptyAdjoint() {
4263
4294
->getCanonicalType ()));
4264
4295
}
4265
4296
4266
- auto adjName = " AD__" + original->getName ().str () + " __adjoint_" +
4267
- getIndices ().mangle ();
4297
+ auto adjName = original->getASTContext ()
4298
+ .getIdentifier (" AD__" + original->getName ().str () +
4299
+ " __adjoint_" + getIndices ().mangle ())
4300
+ .str ();
4268
4301
auto adjType = SILFunctionType::get (
4269
4302
origTy->getGenericSignature (), origTy->getExtInfo (),
4270
4303
origTy->getCoroutineKind (), origTy->getCalleeConvention (), adjParams, {},
4271
4304
adjResults, None, original->getASTContext ());
4272
4305
SILOptFunctionBuilder fb (context.getTransform ());
4273
- auto linkage = original->getLinkage ();
4274
- if (linkage == SILLinkage::Public)
4275
- linkage = SILLinkage::PublicNonABI;
4276
- adjoint = fb.createFunction (linkage, adjName, adjType,
4277
- original->getGenericEnvironment (), original->getLocation (),
4278
- original->isBare (), original->isTransparent (), original->isSerialized ());
4306
+ // We set generated adjoint linkage to Hidden because generated adjoints are
4307
+ // never called cross-module in VJP mode: all cross-module calls to associated
4308
+ // functions call the VJP.
4309
+ // TODO: In order for cross-module calls to work in non-VJP mode, we must use
4310
+ // `getAutoDiffFunctionLinkage` to make the linkage occasionally public. We'll
4311
+ // also need to update TBDGen to generate TBD entries for public adjoints.
4312
+ auto linkage = SILLinkage::Hidden;
4313
+ adjoint = fb.createFunction (
4314
+ linkage, adjName, adjType, original->getGenericEnvironment (),
4315
+ original->getLocation (), original->isBare (), original->isTransparent (),
4316
+ original->isSerialized ());
4279
4317
adjoint->setUnqualifiedOwnership ();
4280
4318
adjoint->setDebugScope (new (module )
4281
4319
SILDebugScope (original->getLocation (), adjoint));
4282
4320
4283
- attr->setAdjointName (adjoint-> getName () , /* primitive*/ false );
4321
+ attr->setAdjointName (adjName , /* primitive*/ false );
4284
4322
}
4285
4323
4286
4324
void DifferentiationTask::createVJP () {
@@ -4300,22 +4338,25 @@ void DifferentiationTask::createVJP() {
4300
4338
auto originalTy = original->getLoweredFunctionType ();
4301
4339
4302
4340
// === Create an empty VJP. ===
4303
- auto vjpName = " AD__" + original->getName ().str () + " __vjp_" +
4304
- getIndices ().mangle ();
4341
+ auto vjpName = original->getASTContext ()
4342
+ .getIdentifier (" AD__" + original->getName ().str () +
4343
+ " __vjp_" + getIndices ().mangle ())
4344
+ .str ();
4305
4345
auto vjpType = originalTy->getAutoDiffAssociatedFunctionType (
4306
4346
getIndices ().parameters , getIndices ().source , 1 ,
4307
4347
AutoDiffAssociatedFunctionKind::VJP, module ,
4308
4348
LookUpConformanceInModule (module .getSwiftModule ()));
4309
4349
4310
4350
SILOptFunctionBuilder fb (context.getTransform ());
4311
- vjp = fb.createFunction (
4312
- original->getLinkage (), vjpName, vjpType,
4313
- original->getGenericEnvironment (), original->getLocation (),
4314
- original->isBare (), original->isTransparent (), original->isSerialized ());
4351
+ auto linkage = getAutoDiffFunctionLinkage (original->getLinkage ());
4352
+ vjp = fb.createFunction (linkage, vjpName, vjpType,
4353
+ original->getGenericEnvironment (),
4354
+ original->getLocation (), original->isBare (),
4355
+ original->isTransparent (), original->isSerialized ());
4315
4356
vjp->setUnqualifiedOwnership ();
4316
4357
vjp->setDebugScope (new (module )
4317
4358
SILDebugScope (original->getLocation (), vjp));
4318
- attr->setVJPName (vjp-> getName () );
4359
+ attr->setVJPName (vjpName );
4319
4360
4320
4361
// Work around a bad interaction between VJPs, TFDeabstraction, and SIL
4321
4362
// optimizations.
0 commit comments