Skip to content

Commit fe6998d

Browse files
authored
[AutoDiff] Enable derivative registration for imported C functions. (#29016)
- Edit `SILFunctionType::getAutoDiffDerivativeFunctionType`. - Return `@convention(thin)` derivative function types given `@convention(c)` original function types. - Necessary because derivative function types have multiple results, which is unsupported for `@convention(c)` functions. - Revert derivative symbol TBDGen to visit `@differentiable` and `@derivative` attributes instead of visiting derivative configurations. - Necessary for `@derivative` attributes with cross-file original functions. `TBDGenVisitor` does not visit functions in other files, so the derivative configurations of the cross-file original function would not be visited. - SILGen: strip external from differentiability witness linkages. - Clang-imported functions have `public_external` linkage, but their differentiability witnesses should have `public` linkage. - Set `SILDeclRef` `isForeign` correctly for original functions that require a foreign entry point (including imported functions). - Done in SILGen, TBDGen, derivative lookup, IRGen. Resolves TF-1087: `@derivative` attribute SILGen crash for foreign functions. Unblocks TF-1085: using `@derivative` attribute for stdlib tgmath functions.
1 parent c23afc8 commit fe6998d

File tree

12 files changed

+136
-31
lines changed

12 files changed

+136
-31
lines changed

lib/IRGen/GenDecl.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,9 +1078,6 @@ void IRGenerator::emitGlobalTopLevel(llvm::StringSet<> *linkerDirectives) {
10781078
// Emit differentiability witnesses.
10791079
for (auto &dw :
10801080
PrimaryIGM->getSILModule().getDifferentiabilityWitnessList()) {
1081-
if (dw.isDeclaration())
1082-
continue;
1083-
10841081
// Emit into same IRGenModule as the original function.
10851082
// NOTE(TF-894): Investigate whether `getGenModule(dw.getVJP())` is
10861083
// significant/desirable; `getGenModule` seems relevant for multi-threaded
@@ -4487,7 +4484,7 @@ IRGenModule::getAddrOfWitnessTablePattern(const NormalProtocolConformance *conf,
44874484
}
44884485

44894486
// SWIFT_ENABLE_TENSORFLOW
4490-
/// Look up the address of a witness table.
4487+
/// Look up the address of a differentiability witness.
44914488
llvm::Constant *IRGenModule::getAddrOfDifferentiabilityWitness(
44924489
const SILDifferentiabilityWitness *witness, ConstantInit definition) {
44934490
auto entity = LinkEntity::forDifferentiabilityWitness(witness);

lib/IRGen/GenDiffWitness.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ void IRGenModule::emitSILDifferentiabilityWitness(
3434
if (dw->isDeclaration())
3535
return;
3636

37-
// Don't emit public_external witnesses.
38-
if (hasPublicVisibility(dw->getLinkage()) &&
39-
isAvailableExternally(dw->getLinkage()))
37+
// Don't emit `public_external` witnesses.
38+
if (dw->getLinkage() == SILLinkage::PublicExternal)
4039
return;
4140

4241
ConstantInitBuilder builder(*this);

lib/SIL/SILFunctionType.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,13 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
423423
if (getSubstGenericSignature() && derivativeFnGenSig &&
424424
!derivativeFnGenSig->areAllParamsConcrete())
425425
canGenSig = derivativeFnGenSig;
426-
return SILFunctionType::get(canGenSig, getExtInfo(), getCoroutineKind(),
426+
// If original function is `@convention(c)`, the derivative function should
427+
// have `@convention(thin)`. IRGen does not support `@convention(c)` functions
428+
// with multiple results.
429+
auto extInfo = getExtInfo();
430+
if (getRepresentation() == SILFunctionTypeRepresentation::CFunctionPointer)
431+
extInfo = extInfo.withRepresentation(SILFunctionTypeRepresentation::Thin);
432+
return SILFunctionType::get(canGenSig, extInfo, getCoroutineKind(),
427433
getCalleeConvention(), newParameters, getYields(),
428434
newResults, getOptionalErrorResult(),
429435
getSubstitutions(), isGenericSignatureImplied(),

lib/SILGen/SILGen.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -755,11 +755,8 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
755755
F->print(llvm::dbgs()));
756756

757757
// SWIFT_ENABLE_TENSORFLOW
758-
// Visit `@differentiable` attributes and generate SIL differentiability
759-
// witnesses.
760-
// TODO(TF-835): Visit `@derivative` attributes when type-checking no longer
761-
// generates implicit `@differentiable` attributes. See TF-835 for replacement
762-
// code.
758+
// Visit `@differentiable` amd `@derivative` attributes and generate SIL
759+
// differentiability witnesses.
763760
// Skip if the SILDeclRef is a:
764761
// - Default argument generator function.
765762
// - Thunk.
@@ -796,7 +793,9 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
796793
break;
797794
}
798795
auto *origAFD = derivAttr->getOriginalFunction();
799-
auto *origFn = getFunction(SILDeclRef(origAFD), NotForDefinition);
796+
auto origDeclRef =
797+
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
798+
auto *origFn = getFunction(origDeclRef, NotForDefinition);
800799
auto derivativeGenSig = AFD->getGenericSignature();
801800
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
802801
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
@@ -854,10 +853,13 @@ void SILGenModule::emitDifferentiabilityWitness(
854853
SILDifferentiabilityWitnessKey key{originalFunction->getName(), silConfig};
855854
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
856855
if (!diffWitness) {
856+
// Strip external from linkage of original function.
857+
// Necessary for Clang-imported functions, which have external linkage.
858+
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
857859
diffWitness = SILDifferentiabilityWitness::createDefinition(
858-
M, originalFunction->getLinkage(), originalFunction,
859-
silConfig.parameterIndices, silConfig.resultIndices,
860-
config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr,
860+
M, linkage, originalFunction, silConfig.parameterIndices,
861+
silConfig.resultIndices, config.derivativeGenericSignature,
862+
/*jvp*/ nullptr, /*vjp*/ nullptr,
861863
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
862864
attr);
863865
}
@@ -881,8 +883,10 @@ void SILGenModule::emitDifferentiabilityWitness(
881883
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
882884
kind, config.parameterIndices, config.derivativeGenericSignature,
883885
getASTContext());
886+
auto origDeclRef = SILDeclRef(originalAFD)
887+
.asForeign(requiresForeignEntryPoint(originalAFD));
884888
derivativeThunk = getOrCreateAutoDiffDerivativeForwardingThunk(
885-
SILDeclRef(originalAFD).asAutoDiffDerivativeFunction(id), derivative,
889+
origDeclRef.asAutoDiffDerivativeFunction(id), derivative,
886890
expectedDerivativeType);
887891
}
888892
// Check for existing same derivative.

lib/SILOptimizer/Utils/Differentiation/Common.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
331331
if (resultIndices->getCapacity() != 1 || !resultIndices->contains(0))
332332
return nullptr;
333333

334-
// Explicit differentiability witnesses only exist on SILFunctions that come
334+
// Explicit differentiability witnesses only exist on SIL functions that come
335335
// from AST functions.
336336
auto *originalAFD = findAbstractFunctionDecl(original);
337337
if (!originalAFD)
@@ -343,8 +343,16 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
343343
if (!minimalConfig)
344344
return nullptr;
345345

346-
auto *existingWitness = module.lookUpDifferentiabilityWitness(
347-
{original->getName(), *minimalConfig});
346+
std::string originalName = original->getName();
347+
// If original function requires a foreign entry point, use the foreign SIL
348+
// function to get or create the minimal differentiability witness.
349+
if (requiresForeignEntryPoint(originalAFD)) {
350+
originalName = SILDeclRef(originalAFD).asForeign().mangle();
351+
original = module.lookUpFunction(SILDeclRef(originalAFD).asForeign());
352+
}
353+
354+
auto *existingWitness =
355+
module.lookUpDifferentiabilityWitness({originalName, *minimalConfig});
348356
if (existingWitness)
349357
return existingWitness;
350358

@@ -358,6 +366,5 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
358366
minimalConfig->derivativeGenericSignature);
359367
}
360368

361-
362369
} // end namespace autodiff
363370
} // end namespace swift

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,6 @@ bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
472472

473473
// TODO: Pattern match to make sure there is at least one `store` to the
474474
// array's active buffer.
475-
// if (isArrayLiteralIntrinsic(ai) && hasActiveResults)
476475
if (isArrayLiteralIntrinsic(applySite) && hasActiveResults)
477476
return true;
478477

lib/TBDGen/TBDGen.cpp

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
241241
AutoDiffConfig config,
242242
AutoDiffLinearMapKind kind) {
243243
auto &ctx = original->getASTContext();
244-
auto declRef = SILDeclRef(original);
244+
auto declRef =
245+
SILDeclRef(original).asForeign(requiresForeignEntryPoint(original));
245246

246247
if (!declRef.isSerialized())
247248
return;
@@ -270,31 +271,46 @@ void TBDGenVisitor::addAutoDiffDerivativeFunction(
270271
auto *assocFnId = AutoDiffDerivativeFunctionIdentifier::get(
271272
kind, parameterIndices, derivativeGenericSignature,
272273
original->getASTContext());
273-
addSymbol(SILDeclRef(original).asAutoDiffDerivativeFunction(assocFnId));
274+
auto declRef =
275+
SILDeclRef(original).asForeign(requiresForeignEntryPoint(original));
276+
addSymbol(declRef.asAutoDiffDerivativeFunction(assocFnId));
274277
}
275278

276279
void TBDGenVisitor::addDifferentiabilityWitness(
277280
AbstractFunctionDecl *original, IndexSubset *astParameterIndices,
278281
IndexSubset *resultIndices, GenericSignature derivativeGenericSignature) {
279-
if (SILDeclRef(original).getLinkage(ForDefinition) != SILLinkage::Public)
282+
bool foreign = requiresForeignEntryPoint(original);
283+
auto declRef = SILDeclRef(original).asForeign(foreign);
284+
285+
// Skip symbol emission for original functions that do not have public
286+
// linkage. Exclude original functions that require a foreign entry point with
287+
// `public_external` linkage.
288+
auto originalLinkage = declRef.getLinkage(ForDefinition);
289+
if (foreign)
290+
originalLinkage = stripExternalFromLinkage(originalLinkage);
291+
if (originalLinkage != SILLinkage::Public)
280292
return;
281293

282294
auto *silParamIndices = autodiff::getLoweredParameterIndices(
283295
astParameterIndices,
284296
original->getInterfaceType()->castTo<AnyFunctionType>());
285297

286-
std::string originalMangledName = SILDeclRef(original).mangle();
298+
auto originalMangledName = declRef.mangle();
287299
AutoDiffConfig config{silParamIndices, resultIndices,
288300
derivativeGenericSignature};
289301
SILDifferentiabilityWitnessKey key(originalMangledName, config);
290302

291303
Mangle::ASTMangler mangler;
292-
std::string mangledName = mangler.mangleSILDifferentiabilityWitnessKey(key);
304+
auto mangledName = mangler.mangleSILDifferentiabilityWitnessKey(key);
293305
addSymbol(mangledName);
294306
}
295307

296308
void TBDGenVisitor::addDerivativeConfiguration(AbstractFunctionDecl *original,
297309
AutoDiffConfig config) {
310+
auto inserted = AddedDerivatives.insert({original, config});
311+
if (!inserted.second)
312+
return;
313+
298314
addAutoDiffLinearMapFunction(original, config,
299315
AutoDiffLinearMapKind::Differential);
300316
addAutoDiffLinearMapFunction(original, config,
@@ -374,9 +390,21 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
374390
}
375391

376392
// SWIFT_ENABLE_TENSORFLOW
377-
for (auto derivativeConfig : AFD->getDerivativeFunctionConfigurations()) {
378-
addDerivativeConfiguration(AFD, derivativeConfig);
379-
}
393+
for (const auto *differentiableAttr :
394+
AFD->getAttrs().getAttributes<DifferentiableAttr>())
395+
addDerivativeConfiguration(
396+
AFD,
397+
AutoDiffConfig(differentiableAttr->getParameterIndices(),
398+
IndexSubset::get(AFD->getASTContext(), 1, {0}),
399+
differentiableAttr->getDerivativeGenericSignature()));
400+
for (const auto *derivativeAttr :
401+
AFD->getAttrs().getAttributes<DerivativeAttr>())
402+
addDerivativeConfiguration(
403+
derivativeAttr->getOriginalFunction(),
404+
AutoDiffConfig(derivativeAttr->getParameterIndices(),
405+
IndexSubset::get(AFD->getASTContext(), 1, {0}),
406+
AFD->getGenericSignature()));
407+
// SWIFT_ENABLE_TENSORFLOW END
380408

381409
visitDefaultArguments(AFD, AFD->getParameters());
382410
}
@@ -430,6 +458,16 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) {
430458
ASD->visitEmittedAccessors([&](AccessorDecl *accessor) {
431459
visitFuncDecl(accessor);
432460
});
461+
462+
// SWIFT_ENABLE_TENSORFLOW
463+
for (const auto *differentiableAttr :
464+
ASD->getAttrs().getAttributes<DifferentiableAttr>())
465+
addDerivativeConfiguration(
466+
ASD->getAccessor(AccessorKind::Get),
467+
AutoDiffConfig(differentiableAttr->getParameterIndices(),
468+
IndexSubset::get(ASD->getASTContext(), 1, {0}),
469+
differentiableAttr->getDerivativeGenericSignature()));
470+
// SWIFT_ENABLE_TENSORFLOW END
433471
}
434472

435473
void TBDGenVisitor::visitVarDecl(VarDecl *VD) {

lib/TBDGen/TBDGenVisitor.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ class TBDGenVisitor : public ASTVisitor<TBDGenVisitor> {
5656
const TBDGenOptions &Opts;
5757
Decl* TopLevelDecl = nullptr;
5858

59+
// SWIFT_ENABLE_TENSORFLOW
60+
/// A set of original function and derivative configuration pairs for which
61+
/// derivative symbols have been emitted.
62+
///
63+
/// Used to deduplicate derivative symbol emission for `@differentiable` and
64+
/// `@derivative` attributes.
65+
llvm::DenseSet<std::pair<AbstractFunctionDecl *, AutoDiffConfig>>
66+
AddedDerivatives;
67+
5968
private:
6069
void addSymbolInternal(StringRef name, llvm::MachO::SymbolKind kind,
6170
bool isLinkerDirective = false);
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
float cFunction(float x) { return x; }
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
float cFunction(float);
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
module CForeign {
2+
header "Foreign.h"
3+
export *
4+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %clang -shared %S/Inputs/Foreign.c -fmodules -o %t/%target-library-name(CForeign)
3+
// RUN: %target-swift-emit-silgen -Xllvm -enable-experimental-cross-file-derivative-registration -I %S/Inputs -I %t %s | %FileCheck %s --check-prefix=CHECK-SILGEN --check-prefix=CHECK
4+
// RUN: %target-swift-emit-sil -Xllvm -enable-experimental-cross-file-derivative-registration -I %S/Inputs -I %t %s | %FileCheck %s --check-prefix=CHECK-SIL --check-prefix=CHECK
5+
// RUN: %target-build-swift -Xllvm -enable-experimental-cross-file-derivative-registration -I %S/Inputs -I %t %s -L %t -lCForeign
6+
7+
import CForeign
8+
9+
// TF-1087: Test derivative registration for foreign declaration (Clang-imported).
10+
// Original SILDeclRef must have `isForeign` bit set correctly.
11+
12+
// CHECK-SILGEN-LABEL: // differentiability witness for cFunction
13+
// CHECK-SILGEN: sil_differentiability_witness [serialized] [parameters 0] [results 0] @cFunction : $@convention(c) (Float) -> Float {
14+
// CHECK-SILGEN: vjp: @AD__cFunction__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
15+
// CHECK-SILGEN: }
16+
17+
// CHECK-SIL-LABEL: // differentiability witness for cFunction
18+
// CHECK-SIL: sil_differentiability_witness [serialized] [parameters 0] [results 0] @cFunction : $@convention(c) (Float) -> Float {
19+
// CHECK-SIL: jvp: @AD__cFunction__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
20+
// CHECK-SIL: vjp: @AD__cFunction__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
21+
// CHECK-SIL: }
22+
23+
// Check that original SIL function is correct.
24+
25+
// CHECK-SILGEN-LABEL: sil [serializable] [clang cFunction] @cFunction : $@convention(c) (Float) -> Float
26+
27+
@inlinable
28+
@derivative(of: cFunction)
29+
func vjpCFunction(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
30+
(cFunction(x), { $0 })
31+
}
32+
33+
@_silgen_name("test_derivative")
34+
@differentiable
35+
func testDerivative(_ x: Float) -> Float {
36+
cFunction(x)
37+
}
38+
39+
// CHECK-SILGEN-LABEL: sil hidden [ossa] @test_derivative : $@convention(thin) (Float) -> Float {
40+
// CHECK-SILGEN: {{%.*}} = function_ref @cFunction : $@convention(c) (Float) -> Float

0 commit comments

Comments
 (0)