Skip to content

Commit c82b9f3

Browse files
authored
[AutoDiff] Directly SILGen @derivative attributes to diff witnesses. (#28621)
Previously, `@derivative` attribute type-checking created implicit `@differentiable` attributes on the original declaration. This was a longstanding hack powering `@derivative` attribute derivative registration. #28608 made these changes: - Derivative function configurations (from `@differentiable` and `@derivative` attributes) are serialized in modules and are loaded from imported modules. - The differentiation transform uses these derivative function configurations for derivative function lookup instead of `@differentiable` attributes. Now, `@derivative` attributes are directly lowered to differentiability witnesses during SILGen, and implicit `@differentiable` attribute generation is removed. `@derivative` attributes are now serialized. Resolves TF-835. Unblocks TF-1021: lifting the "same-file derivative registration only" limitation in `@derivative` attribute type-checking. This should be trivially possible but requires more testing. Exposes TF-1037: crash due to `@differentiable` + `@derivative` attribute with different derivative generic signatures. Exposes TF-1040: `@differentiable` attribute limitations for class methods. Exposes TF-1041: untested protocol requirement `@differentiable` attribute type-checking logic. Tracks TF-1042: remove `ASTContext::{Differentiable,Derivative}Attrs` and use `AbstractFunctionDecl::getDerivativeFunctionConfigurations` to detect duplicate `@differentiable` + `@derivative` attributes.
1 parent f0b1c0b commit c82b9f3

23 files changed

+579
-204
lines changed

include/swift/AST/ASTContext.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ namespace swift {
112112
class IndexSubset;
113113
// SWIFT_ENABLE_TENSORFLOW
114114
struct AutoDiffConfig;
115-
class VectorSpace;
115+
struct AutoDiffDerivativeFunctionKind;
116+
class DerivativeAttr;
116117
class DifferentiableAttr;
118+
class VectorSpace;
117119
// SWIFT_ENABLE_TENSORFLOW END
118120

119121
enum class KnownProtocolKind : uint8_t;
@@ -290,11 +292,26 @@ class ASTContext final {
290292
/// Cache of autodiff-associated vector spaces.
291293
llvm::DenseMap<Type, Optional<VectorSpace>> AutoDiffVectorSpaces;
292294

293-
/// Cache of `@differentiable` attributes keyed by parameter indices. This
294-
/// helps us diagnose multiple `@differentiable`s that are with respect to the
295-
/// same set of parameters.
295+
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
296+
/// diagnose duplicate `@differentiable` attributes for the same key.
297+
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
298+
// signature as a key is possible. It requires derivative generic signature
299+
// mangling to avoid name collisions for SIL derivative functions with the
300+
// same parameter indices but different derivative generic signatures.
296301
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
297302
DifferentiableAttrs;
303+
304+
/// Cache of `@derivative` attributes keyed by parameter indices and
305+
/// derivative function kind. Used to diagnose duplicate `@derivative`
306+
/// attributes for the same key.
307+
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
308+
// signature as a key is possible. It requires derivative generic signature
309+
// mangling to avoid name collisions for SIL derivative functions with the
310+
// same parameter indices but different derivative generic signatures.
311+
llvm::DenseMap<
312+
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
313+
DerivativeAttr *>
314+
DerivativeAttrs;
298315
// SWIFT_ENABLE_TENSORFLOW END
299316

300317
private:

include/swift/AST/Attr.def

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,8 @@ DECL_ATTR(differentiable, Differentiable,
513513
91)
514514
DECL_ATTR(derivative, Derivative,
515515
OnFunc | LongAttribute | AllowMultipleAttributes |
516-
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
517-
NotSerialized, 92)
516+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
517+
92)
518518
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
519519
OnAccessor | OnFunc | OnConstructor | OnSubscript |
520520
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
@@ -542,8 +542,8 @@ DECL_ATTR(quoted, Quoted,
542542
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
543543
DECL_ATTR(differentiating, Differentiating,
544544
OnFunc | LongAttribute | AllowMultipleAttributes |
545-
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
546-
NotSerialized, 98)
545+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
546+
98)
547547
// SWIFT_ENABLE_TENSORFLOW END
548548

549549
#undef TYPE_ATTR

include/swift/AST/Attr.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,7 +1841,7 @@ class DifferentiableAttr final
18411841

18421842
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
18431843
SourceRange baseRange, bool linear,
1844-
IndexSubset *indices,
1844+
IndexSubset *parameterIndices,
18451845
Optional<DeclNameWithLoc> jvp,
18461846
Optional<DeclNameWithLoc> vjp,
18471847
GenericSignature derivativeGenericSignature);
@@ -1855,9 +1855,10 @@ class DifferentiableAttr final
18551855
Optional<DeclNameWithLoc> vjp,
18561856
TrailingWhereClause *clause);
18571857

1858-
static DifferentiableAttr *create(Decl *original, bool implicit,
1859-
SourceLoc atLoc, SourceRange baseRange,
1860-
bool linear, IndexSubset *indices,
1858+
static DifferentiableAttr *create(AbstractFunctionDecl *original,
1859+
bool implicit, SourceLoc atLoc,
1860+
SourceRange baseRange, bool linear,
1861+
IndexSubset *parameterIndices,
18611862
Optional<DeclNameWithLoc> jvp,
18621863
Optional<DeclNameWithLoc> vjp,
18631864
GenericSignature derivativeGenSig);
@@ -1947,6 +1948,8 @@ class DerivativeAttr final
19471948
unsigned NumParsedParameters = 0;
19481949
/// The differentiation parameters' indices, resolved by the type checker.
19491950
IndexSubset *ParameterIndices = nullptr;
1951+
/// The derivative function kind (JVP or VJP), resolved by the type checker.
1952+
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
19501953

19511954
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
19521955
DeclNameWithLoc original,
@@ -1975,6 +1978,12 @@ class DerivativeAttr final
19751978
OriginalFunction = decl;
19761979
}
19771980

1981+
AutoDiffDerivativeFunctionKind getDerivativeKind() const {
1982+
assert(Kind && "Derivative function kind has not yet been resolved");
1983+
return *Kind;
1984+
}
1985+
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
1986+
19781987
/// The parsed differentiation parameters, i.e. the list of parameters
19791988
/// specified in 'wrt:'.
19801989
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {

include/swift/AST/AutoDiff.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ struct AutoDiffConfig {
306306
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
307307
const AutoDiffDerivativeFunctionKind kind;
308308
IndexSubset *const parameterIndices;
309+
// TODO(TF-680): Mangle derivative generic signature requirements as well.
309310

310311
AutoDiffDerivativeFunctionIdentifier(
311312
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
@@ -508,6 +509,27 @@ template<> struct DenseMapInfo<AutoDiffConfig> {
508509
}
509510
};
510511

512+
template<> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
513+
static AutoDiffDerivativeFunctionKind getEmptyKey() {
514+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
515+
DenseMapInfo<unsigned>::getEmptyKey());
516+
}
517+
518+
static AutoDiffDerivativeFunctionKind getTombstoneKey() {
519+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
520+
DenseMapInfo<unsigned>::getTombstoneKey());
521+
}
522+
523+
static unsigned getHashValue(const AutoDiffDerivativeFunctionKind &Val) {
524+
return DenseMapInfo<unsigned>::getHashValue(Val);
525+
}
526+
527+
static bool isEqual(const AutoDiffDerivativeFunctionKind &LHS,
528+
const AutoDiffDerivativeFunctionKind &RHS) {
529+
return LHS == RHS;
530+
}
531+
};
532+
511533
template<> struct DenseMapInfo<SILAutoDiffIndices> {
512534
static SILAutoDiffIndices getEmptyKey() {
513535
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };

lib/AST/Attr.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -948,8 +948,8 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
948948
Printer.printAttrName("@derivative");
949949
Printer << "(of: ";
950950
auto *attr = cast<DerivativeAttr>(this);
951-
auto *derivative = cast<AbstractFunctionDecl>(D);
952951
Printer << attr->getOriginalFunctionName().Name;
952+
auto *derivative = cast<AbstractFunctionDecl>(D);
953953
auto diffParamsString = getDifferentiationParametersClauseString(
954954
derivative, attr->getParameterIndices(), attr->getParsedParameters());
955955
if (!diffParamsString.empty())
@@ -963,8 +963,8 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
963963
Printer.printAttrName("@transpose");
964964
Printer << '(';
965965
auto *attr = cast<TransposeAttr>(this);
966-
auto *transpose = cast<AbstractFunctionDecl>(D);
967966
Printer << attr->getOriginalFunctionName().Name;
967+
auto *transpose = cast<AbstractFunctionDecl>(D);
968968
auto transParamsString = getTransposedParametersClauseString(
969969
transpose, attr->getParameterIndices(), attr->getParsedParameters());
970970
if (!transParamsString.empty())
@@ -1492,16 +1492,24 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
14921492
}
14931493

14941494
DifferentiableAttr *
1495-
DifferentiableAttr::create(Decl *original, bool implicit, SourceLoc atLoc,
1496-
SourceRange baseRange, bool linear,
1497-
IndexSubset *indices, Optional<DeclNameWithLoc> jvp,
1495+
DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
1496+
SourceLoc atLoc, SourceRange baseRange, bool linear,
1497+
IndexSubset *parameterIndices,
1498+
Optional<DeclNameWithLoc> jvp,
14981499
Optional<DeclNameWithLoc> vjp,
14991500
GenericSignature derivativeGenSig) {
15001501
auto &ctx = original->getASTContext();
15011502
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
15021503
alignof(DifferentiableAttr));
1504+
// Register derivative function configuration for the given original
1505+
// declaration.
1506+
// NOTE(TF-1038): `@differentiable` attributes currently always have
1507+
// effective result indices `{0}` (the first and only result index).
1508+
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
1509+
original->addDerivativeFunctionConfiguration(
1510+
{parameterIndices, resultIndices, derivativeGenSig});
15031511
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
1504-
linear, indices, std::move(jvp),
1512+
linear, parameterIndices, std::move(jvp),
15051513
std::move(vjp), derivativeGenSig);
15061514
}
15071515

lib/SILGen/SILGen.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,26 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
778778
diffAttr->getDerivativeGenericSignature());
779779
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
780780
}
781+
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
782+
SILFunction *jvp = nullptr;
783+
SILFunction *vjp = nullptr;
784+
switch (derivAttr->getDerivativeKind()) {
785+
case AutoDiffDerivativeFunctionKind::JVP:
786+
jvp = F;
787+
break;
788+
case AutoDiffDerivativeFunctionKind::VJP:
789+
vjp = F;
790+
break;
791+
}
792+
auto *origAFD = derivAttr->getOriginalFunction();
793+
auto *origFn = getFunction(SILDeclRef(origAFD), NotForDefinition);
794+
auto derivativeGenSig = AFD->getGenericSignature();
795+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
796+
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
797+
derivativeGenSig);
798+
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
799+
derivAttr);
800+
}
781801
};
782802
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
783803
if (accessor->isGetter())
@@ -790,21 +810,22 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
790810
void SILGenModule::emitDifferentiabilityWitness(
791811
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
792812
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
793-
const DeclAttribute *diffAttr) {
813+
const DeclAttribute *attr) {
814+
assert(isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
794815
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
795816
auto origSilFnType = originalFunction->getLoweredFunctionType();
796-
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
797-
config.parameterIndices, origFnType);
817+
auto *silParamIndices =
818+
autodiff::getLoweredParameterIndices(config.parameterIndices, origFnType);
798819
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
799820
// parameters corresponding to captured variables. These parameters do not
800821
// appear in the type of `origFnType`.
801822
// TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
802823
// take `CaptureInfo` into account.
803-
if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity())
804-
loweredParamIndices = loweredParamIndices->extendingCapacity(
824+
if (origSilFnType->getNumParameters() > silParamIndices->getCapacity())
825+
silParamIndices = silParamIndices->extendingCapacity(
805826
getASTContext(), origSilFnType->getNumParameters());
806827
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
807-
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
828+
SILAutoDiffIndices indices(/*source*/ 0, silParamIndices);
808829

809830
// Self reordering thunk is necessary if wrt at least two parameters,
810831
// including self.
@@ -818,14 +839,22 @@ void SILGenModule::emitDifferentiabilityWitness(
818839
};
819840
bool reorderSelf = shouldReorderSelf();
820841

821-
// Create new SIL differentiability witness.
842+
// Get or create new SIL differentiability witness.
843+
// Witness already exists when there are two `@derivative` attributes (JVP and
844+
// VJP) for the same derivative function configuration.
822845
// Witness JVP and VJP are set below.
823-
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
824-
M, originalFunction->getLinkage(), originalFunction, loweredParamIndices,
825-
config.resultIndices, config.derivativeGenericSignature,
826-
/*jvp*/ nullptr, /*vjp*/ nullptr,
827-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
828-
diffAttr);
846+
AutoDiffConfig silConfig(silParamIndices, config.resultIndices,
847+
config.derivativeGenericSignature);
848+
SILDifferentiabilityWitnessKey key{originalFunction->getName(), silConfig};
849+
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
850+
if (!diffWitness) {
851+
diffWitness = SILDifferentiabilityWitness::createDefinition(
852+
M, originalFunction->getLinkage(), originalFunction,
853+
silConfig.parameterIndices, silConfig.resultIndices,
854+
config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr,
855+
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
856+
attr);
857+
}
829858

830859
// Set derivative function in differentiability witness.
831860
auto setDerivativeInDifferentiabilityWitness =

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,6 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
586586
auto memberAssocContextualType =
587587
parentDC->mapTypeIntoContext(memberAssocInterfaceType);
588588
newMember->setInterfaceType(memberAssocInterfaceType);
589-
// newMember->setType(memberAssocContextualType);
590589
Pattern *memberPattern =
591590
new (C) NamedPattern(newMember, /*implicit*/ true);
592591
memberPattern->setType(memberAssocContextualType);
@@ -623,10 +622,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
623622
derivativeGenSig = extDecl->getGenericSignature();
624623
auto *diffableAttr = DifferentiableAttr::create(
625624
getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
626-
/*linear*/ false, {}, None, None, derivativeGenSig);
625+
/*linear*/ false, /*parameterIndices*/ IndexSubset::get(C, 1, {0}),
626+
/*jvp*/ None, /*vjp*/ None, derivativeGenSig);
627627
member->getAttrs().add(diffableAttr);
628-
// Set getter `@differentiable` attribute parameter indices.
629-
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));
630628
}
631629
}
632630

0 commit comments

Comments
 (0)