Skip to content

Commit 8dfd83b

Browse files
author
Marc Rasi
committed
[AutoDiff] draft of lifting samefile derivative constriant
1 parent e5915f7 commit 8dfd83b

File tree

19 files changed

+1070
-713
lines changed

19 files changed

+1070
-713
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3016,12 +3016,15 @@ NOTE(derivative_attr_result_func_type_mismatch_note,none,
30163016
"%0 does not have expected type %1", (Identifier, Type))
30173017
NOTE(derivative_attr_result_func_original_note,none,
30183018
"%0 defined here", (DeclName))
3019-
ERROR(derivative_attr_not_in_same_file_as_original,none,
3020-
"derivative not in the same file as the original function", ())
30213019
ERROR(derivative_attr_original_stored_property_unsupported,none,
30223020
"cannot register derivative for stored property %0", (DeclName))
30233021
ERROR(derivative_attr_original_already_has_derivative,none,
30243022
"a derivative already exists for %0", (DeclName))
3023+
ERROR(derivative_attr_visibility_too_broad,none,
3024+
"derivative function visibility must be at least as restrictive as original function "
3025+
"visibility", ())
3026+
NOTE(derivative_attr_visibility_too_broad_note,none,
3027+
"original function defined here", ())
30253028

30263029
// @transpose
30273030
ERROR(transpose_params_clause_param_not_differentiable,none,

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class SILDifferentiabilityWitness
9696
SILDifferentiabilityWitnessKey getKey() const;
9797
SILModule &getModule() const { return Module; }
9898
SILLinkage getLinkage() const { return Linkage; }
99+
void setLinkage(SILLinkage linkage) { Linkage = linkage; }
99100
SILFunction *getOriginalFunction() const { return OriginalFunction; }
100101
const AutoDiffConfig &getConfig() const { return Config; }
101102
IndexSubset *getParameterIndices() const {
@@ -127,6 +128,7 @@ class SILDifferentiabilityWitness
127128
bool isDeclaration() const { return IsDeclaration; }
128129
bool isDefinition() const { return !IsDeclaration; }
129130
bool isSerialized() const { return IsSerialized; }
131+
void setSerialized(bool isSerialized) { IsSerialized = isSerialized; }
130132
const DeclAttribute *getAttribute() const { return Attribute; }
131133

132134
/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.

lib/SILGen/SILGen.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,8 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
781781
"all functions with generic signatures");
782782
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
783783
diffAttr->getDerivativeGenericSignature());
784-
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
784+
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr,
785+
F->getLinkage());
785786
}
786787
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
787788
SILFunction *jvp = nullptr;
@@ -801,7 +802,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
801802
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
802803
derivativeGenSig);
803804
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
804-
derivAttr);
805+
derivAttr, F->getLinkage());
805806
}
806807
};
807808
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
@@ -815,7 +816,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
815816
void SILGenModule::emitDifferentiabilityWitness(
816817
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
817818
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
818-
const DeclAttribute *attr) {
819+
const DeclAttribute *attr, SILLinkage witnessLinkage) {
819820
assert(isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
820821
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
821822
auto origSilFnType = originalFunction->getLoweredFunctionType();
@@ -854,11 +855,19 @@ void SILGenModule::emitDifferentiabilityWitness(
854855
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
855856
if (!diffWitness) {
856857
diffWitness = SILDifferentiabilityWitness::createDefinition(
857-
M, originalFunction->getLinkage(), originalFunction,
858-
silConfig.parameterIndices, silConfig.resultIndices,
859-
config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr,
860-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
861-
attr);
858+
M, witnessLinkage, originalFunction, silConfig.parameterIndices,
859+
silConfig.resultIndices, config.derivativeGenericSignature,
860+
/*jvp*/ nullptr, /*vjp*/ nullptr,
861+
/*isSerialized*/ hasPublicVisibility(witnessLinkage), attr);
862+
}
863+
864+
// Use the least restrictive declared linkage, so that e.g. a
865+
// `@differentiable` on `public` function with `@derivative`s on `internal`
866+
// functions results in a public witness. (Sema is responsible for diagnosing
867+
// forbidden combinations).
868+
if (witnessLinkage < diffWitness->getLinkage()) {
869+
diffWitness->setLinkage(witnessLinkage);
870+
diffWitness->setSerialized(hasPublicVisibility(witnessLinkage));
862871
}
863872

864873
// Set derivative function in differentiability witness.

lib/SILGen/SILGen.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
366366
SILFunction *originalFunction,
367367
const AutoDiffConfig &config,
368368
SILFunction *jvp, SILFunction *vjp,
369-
const DeclAttribute *diffAttr);
369+
const DeclAttribute *diffAttr,
370+
SILLinkage witnessLinkage);
370371
// SWIFT_ENABLE_TENSORFLOW END
371372

372373
/// Emit the lazy initializer function for a global pattern binding

lib/Sema/TypeCheckAttr.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3867,12 +3867,15 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
38673867
return;
38683868
}
38693869

3870-
// Reject different-file retroactive derivatives.
3871-
// TODO(TF-136): Lift this restriction now that SIL differentiability witness
3872-
// infrastructure is ready.
3873-
if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
3874-
diagnoseAndRemoveAttr(attr,
3875-
diag::derivative_attr_not_in_same_file_as_original);
3870+
// Check that derivative visibility is at least as restricted as original
3871+
// function visibility.
3872+
if (derivative->getFormalAccessScope() !=
3873+
originalAFD->getFormalAccessScope() &&
3874+
!derivative->getFormalAccessScope().isChildOf(
3875+
originalAFD->getFormalAccessScope())) {
3876+
diagnoseAndRemoveAttr(attr, diag::derivative_attr_visibility_too_broad);
3877+
diagnose(originalAFD->getLoc(),
3878+
diag::derivative_attr_visibility_too_broad_note);
38763879
return;
38773880
}
38783881

lib/TBDGen/TBDGen.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ void TBDGenVisitor::addSymbol(StringRef name, SymbolKind kind) {
6767
if (StringSymbols && kind == SymbolKind::GlobalSymbol) {
6868
auto isNewValue = StringSymbols->insert(mangled).second;
6969
(void)isNewValue;
70+
if (!isNewValue)
71+
llvm::dbgs() << mangled << "\n";
7072
assert(isNewValue && "symbol appears twice");
7173
}
7274
}
@@ -236,6 +238,10 @@ void TBDGenVisitor::addDifferentiabilityWitness(
236238

237239
void TBDGenVisitor::addDerivativeConfiguration(AbstractFunctionDecl *original,
238240
AutoDiffConfig config) {
241+
auto inserted = AddedDerivatives.insert({original, config});
242+
if (!inserted.second)
243+
return;
244+
239245
addAutoDiffLinearMapFunction(original, config,
240246
AutoDiffLinearMapKind::Differential);
241247
addAutoDiffLinearMapFunction(original, config,
@@ -315,9 +321,20 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
315321
}
316322

317323
// SWIFT_ENABLE_TENSORFLOW
318-
for (auto derivativeConfig : AFD->getDerivativeFunctionConfigurations()) {
319-
addDerivativeConfiguration(AFD, derivativeConfig);
320-
}
324+
for (const auto *differentiableAttr :
325+
AFD->getAttrs().getAttributes<DifferentiableAttr>())
326+
addDerivativeConfiguration(
327+
AFD,
328+
AutoDiffConfig(differentiableAttr->getParameterIndices(),
329+
IndexSubset::get(AFD->getASTContext(), 1, {0}),
330+
differentiableAttr->getDerivativeGenericSignature()));
331+
for (const auto *derivativeAttr :
332+
AFD->getAttrs().getAttributes<DerivativeAttr>())
333+
addDerivativeConfiguration(
334+
derivativeAttr->getOriginalFunction(),
335+
AutoDiffConfig(derivativeAttr->getParameterIndices(),
336+
IndexSubset::get(AFD->getASTContext(), 1, {0}),
337+
AFD->getGenericSignature()));
321338

322339
visitDefaultArguments(AFD, AFD->getParameters());
323340
}
@@ -371,6 +388,15 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) {
371388
ASD->visitEmittedAccessors([&](AccessorDecl *accessor) {
372389
visitFuncDecl(accessor);
373390
});
391+
392+
// SWIFT_ENABLE_TENSORFLOW
393+
for (const auto *differentiableAttr :
394+
ASD->getAttrs().getAttributes<DifferentiableAttr>())
395+
addDerivativeConfiguration(
396+
ASD->getAccessor(AccessorKind::Get),
397+
AutoDiffConfig(differentiableAttr->getParameterIndices(),
398+
IndexSubset::get(ASD->getASTContext(), 1, {0}),
399+
differentiableAttr->getDerivativeGenericSignature()));
374400
}
375401

376402
void TBDGenVisitor::visitVarDecl(VarDecl *VD) {

lib/TBDGen/TBDGenVisitor.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ class TBDGenVisitor : public ASTVisitor<TBDGenVisitor> {
5555
ModuleDecl *SwiftModule;
5656
const TBDGenOptions &Opts;
5757

58+
// SWIFT_ENABLE_TENSORFLOW
59+
/// Tracks derivatives that have been added to the TBD.
60+
///
61+
/// Different attributes trigger emission of the same derivatives (e.g.
62+
/// `@differentiable` and `@derivative(of:)`), so we use this to deduplicate
63+
/// the symbols associated with the derivatives in the TBD.
64+
llvm::DenseSet<std::pair<AbstractFunctionDecl *, AutoDiffConfig>>
65+
AddedDerivatives;
66+
5867
private:
5968
void addSymbol(StringRef name, llvm::MachO::SymbolKind kind =
6069
llvm::MachO::SymbolKind::GlobalSymbol);

0 commit comments

Comments
 (0)