Skip to content

Commit c47a4ca

Browse files
committed
[AutoDiff] Directly SILGen @derivative attributes to diff witnesses.
Previously, `@derivative` attribute type-checking created implicit `@differentiable` attributes on the original declaration. This was a longstanding hack powering `@derivative` attribute derivative registration.

swiftlang#28608 made these changes:
- Derivative function configurations (from `@differentiable` and `@derivative` attributes) are serialized in modules and are loaded from imported modules.
- The differentiation transform uses these derivative function configurations for derivative function lookup instead of `@differentiable` attributes.

Now, `@derivative` attributes are directly lowered to differentiability witnesses during SILGen, and implicit `@differentiable` attribute generation is removed.

Type-checking changes:
- "Overlapping" `@differentiable` and `@derivative` attributes (for the same original declaration and parameter indices) are now disallowed. They semantically conflict because the first "requests derivative generation" while the second "registers a derivative".
- "Overlapping" `@differentiable` and `@derivative` attributes are allowed for protocol requirements. Requirement `@differentiable` attributes mean "add JVP/VJP witness table entries" - not "request derivative generation", because there is no function body.
- Note that relaxing the "overlapping" condition to also consider derivative generic signatures becomes possible once derivative generic signatures are mangled for derivative functions: TF-680.

Resolves TF-835.

Unblocks TF-1021: lifting the "same-file derivative registration only" limitation in `@derivative` attribute type-checking. This should be possible without much work, but needs testing!

Exposes TF-1040: `@differentiable` attribute limitations for class methods.

Exposes TF-1041: untested protocol requirement `@differentiable` attribute type-checking logic.
1 parent ee7644d commit c47a4ca

12 files changed

+349
-154
lines changed

include/swift/AST/ASTContext.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ namespace swift {
112112
class IndexSubset;
113113
// SWIFT_ENABLE_TENSORFLOW
114114
struct AutoDiffConfig;
115-
class VectorSpace;
115+
struct AutoDiffDerivativeFunctionKind;
116+
class DerivativeAttr;
116117
class DifferentiableAttr;
118+
class VectorSpace;
117119
// SWIFT_ENABLE_TENSORFLOW END
118120

119121
enum class KnownProtocolKind : uint8_t;
@@ -290,11 +292,26 @@ class ASTContext final {
290292
/// Cache of autodiff-associated vector spaces.
291293
llvm::DenseMap<Type, Optional<VectorSpace>> AutoDiffVectorSpaces;
292294

293-
/// Cache of `@differentiable` attributes keyed by parameter indices. This
294-
/// helps us diagnose multiple `@differentiable`s that are with respect to the
295-
/// same set of parameters.
295+
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
296+
/// diagnose duplicate `@differentiable` attributes for the same key.
297+
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
298+
// signature as a key is possible. It requires derivative generic signature
299+
// mangling to avoid name collisions for SIL derivative functions with the
300+
// same parameter indices but different derivative generic signatures.
296301
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
297302
DifferentiableAttrs;
303+
304+
/// Cache of `@derivative` attributes keyed by parameter indices and
305+
/// derivative function kind. Used to diagnose duplicate `@derivative`
306+
/// attributes for the same key.
307+
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
308+
// signature as a key is possible. It requires derivative generic signature
309+
// mangling to avoid name collisions for SIL derivative functions with the
310+
// same parameter indices but different derivative generic signatures.
311+
llvm::DenseMap<
312+
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
313+
DerivativeAttr *>
314+
DerivativeAttrs;
298315
// SWIFT_ENABLE_TENSORFLOW END
299316

300317
private:

include/swift/AST/Attr.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,6 +1947,8 @@ class DerivativeAttr final
19471947
unsigned NumParsedParameters = 0;
19481948
/// The differentiation parameters' indices, resolved by the type checker.
19491949
IndexSubset *ParameterIndices = nullptr;
1950+
/// The derivative function kind (JVP or VJP), resolved by the type checker.
1951+
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
19501952

19511953
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
19521954
DeclNameWithLoc original,
@@ -1975,6 +1977,12 @@ class DerivativeAttr final
19751977
OriginalFunction = decl;
19761978
}
19771979

1980+
AutoDiffDerivativeFunctionKind getDerivativeKind() const {
1981+
assert(Kind && "Derivative function kind has not yet been resolved");
1982+
return *Kind;
1983+
}
1984+
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
1985+
19781986
/// The parsed differentiation parameters, i.e. the list of parameters
19791987
/// specified in 'wrt:'.
19801988
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {

include/swift/AST/AutoDiff.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ struct AutoDiffConfig {
306306
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
307307
const AutoDiffDerivativeFunctionKind kind;
308308
IndexSubset *const parameterIndices;
309+
// TODO(TF-680): Mangle derivative generic signature requirements as well.
309310

310311
AutoDiffDerivativeFunctionIdentifier(
311312
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
@@ -508,6 +509,27 @@ template<> struct DenseMapInfo<AutoDiffConfig> {
508509
}
509510
};
510511

512+
template<> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
513+
static AutoDiffDerivativeFunctionKind getEmptyKey() {
514+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
515+
DenseMapInfo<unsigned>::getEmptyKey());
516+
}
517+
518+
static AutoDiffDerivativeFunctionKind getTombstoneKey() {
519+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
520+
DenseMapInfo<unsigned>::getTombstoneKey());
521+
}
522+
523+
static unsigned getHashValue(const AutoDiffDerivativeFunctionKind &Val) {
524+
return DenseMapInfo<unsigned>::getHashValue(Val);
525+
}
526+
527+
static bool isEqual(const AutoDiffDerivativeFunctionKind &LHS,
528+
const AutoDiffDerivativeFunctionKind &RHS) {
529+
return LHS == RHS;
530+
}
531+
};
532+
511533
template<> struct DenseMapInfo<SILAutoDiffIndices> {
512534
static SILAutoDiffIndices getEmptyKey() {
513535
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };

lib/SILGen/SILGen.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,26 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
778778
diffAttr->getDerivativeGenericSignature());
779779
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
780780
}
781+
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
782+
SILFunction *jvp = nullptr;
783+
SILFunction *vjp = nullptr;
784+
switch (derivAttr->getDerivativeKind()) {
785+
case AutoDiffDerivativeFunctionKind::JVP:
786+
jvp = F;
787+
break;
788+
case AutoDiffDerivativeFunctionKind::VJP:
789+
vjp = F;
790+
break;
791+
}
792+
auto *origAFD = derivAttr->getOriginalFunction();
793+
auto *origFn = getFunction(SILDeclRef(origAFD), NotForDefinition);
794+
auto derivativeGenSig = AFD->getGenericSignature();
795+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
796+
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
797+
derivativeGenSig);
798+
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
799+
derivAttr);
800+
}
781801
};
782802
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
783803
if (accessor->isGetter())
@@ -790,21 +810,22 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
790810
void SILGenModule::emitDifferentiabilityWitness(
791811
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
792812
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
793-
const DeclAttribute *diffAttr) {
813+
const DeclAttribute *attr) {
814+
assert(isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
794815
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
795816
auto origSilFnType = originalFunction->getLoweredFunctionType();
796-
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
797-
config.parameterIndices, origFnType);
817+
auto *silParamIndices =
818+
autodiff::getLoweredParameterIndices(config.parameterIndices, origFnType);
798819
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
799820
// parameters corresponding to captured variables. These parameters do not
800821
// appear in the type of `origFnType`.
801822
// TODO: If possible, change `autodiff::getLoweredParameterIndices` to
802823
// take `CaptureInfo` into account.
803-
if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity())
804-
loweredParamIndices = loweredParamIndices->extendingCapacity(
824+
if (origSilFnType->getNumParameters() > silParamIndices->getCapacity())
825+
silParamIndices = silParamIndices->extendingCapacity(
805826
getASTContext(), origSilFnType->getNumParameters());
806827
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
807-
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
828+
SILAutoDiffIndices indices(/*source*/ 0, silParamIndices);
808829

809830
// Self reordering thunk is necessary if wrt at least two parameters,
810831
// including self.
@@ -818,14 +839,22 @@ void SILGenModule::emitDifferentiabilityWitness(
818839
};
819840
bool reorderSelf = shouldReorderSelf();
820841

821-
// Create new SIL differentiability witness.
842+
// Get or create new SIL differentiability witness.
843+
// Differentiability witness already exists when there are two `@derivative`
844+
// attributes (JVP and VJP) for the same derivative function configuration.
822845
// Witness JVP and VJP are set below.
823-
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
824-
M, originalFunction->getLinkage(), originalFunction, loweredParamIndices,
825-
config.resultIndices, config.derivativeGenericSignature,
826-
/*jvp*/ nullptr, /*vjp*/ nullptr,
827-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
828-
diffAttr);
846+
AutoDiffConfig silConfig(silParamIndices, config.resultIndices,
847+
config.derivativeGenericSignature);
848+
SILDifferentiabilityWitnessKey key{originalFunction->getName(), silConfig};
849+
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
850+
if (!diffWitness) {
851+
diffWitness = SILDifferentiabilityWitness::createDefinition(
852+
M, originalFunction->getLinkage(), originalFunction,
853+
silConfig.parameterIndices, silConfig.resultIndices,
854+
config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr,
855+
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
856+
attr);
857+
}
829858

830859
// Set derivative function in differentiability witness.
831860
auto setDerivativeInDifferentiabilityWitness =

lib/Sema/TypeCheckAttr.cpp

Lines changed: 29 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -3696,6 +3696,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36963696
attr->setInvalid();
36973697
return;
36983698
}
3699+
attr->setDerivativeKind(kind);
36993700
// `value: R` result tuple element must conform to `Differentiable`.
37003701
auto diffableProto = Ctx.getProtocol(KnownProtocolKind::Differentiable);
37013702
auto valueResultType = valueResultElt.getType();
@@ -3918,74 +3919,35 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
39183919
return;
39193920
}
39203921

3921-
// Try to find a `@differentiable` attribute on the original function with the
3922-
// same differentiation parameters.
3923-
DifferentiableAttr *da = nullptr;
3924-
for (auto *cda : originalAFD->getAttrs().getAttributes<DifferentiableAttr>())
3925-
if (checkedWrtParamIndices == cda->getParameterIndices())
3926-
da = const_cast<DifferentiableAttr *>(cda);
3927-
// If the original function does not have a `@differentiable` attribute with
3928-
// the same differentiation parameters, create one.
3929-
// TODO(TF-835): Lower `@derivative` attributes directly to SIL
3930-
// differentiability witnesses during SILGen instead of generating implicit
3931-
// `@differentiable` attributes.
3932-
if (!da) {
3933-
da = DifferentiableAttr::create(
3934-
originalAFD, /*implicit*/ true, attr->AtLoc, attr->getRange(),
3935-
/*linear*/ false, checkedWrtParamIndices, /*jvp*/ None,
3936-
/*vjp*/ None, derivative->getGenericSignature());
3937-
switch (kind) {
3938-
case AutoDiffDerivativeFunctionKind::JVP:
3939-
da->setJVPFunction(derivative);
3940-
break;
3941-
case AutoDiffDerivativeFunctionKind::VJP:
3942-
da->setVJPFunction(derivative);
3943-
break;
3944-
}
3945-
auto insertion = Ctx.DifferentiableAttrs.try_emplace(
3946-
{originalAFD, checkedWrtParamIndices}, da);
3947-
// Valid `@differentiable` attributes are uniqued by their parameter
3948-
// indices. Reject duplicate attributes for the same decl and parameter
3949-
// indices pair.
3950-
if (!insertion.second && insertion.first->getSecond() != da) {
3951-
diagnoseAndRemoveAttr(da, diag::differentiable_attr_duplicate);
3952-
diagnose(insertion.first->getSecond()->getLocation(),
3953-
diag::differentiable_attr_duplicate_note);
3954-
return;
3955-
}
3956-
originalAFD->getAttrs().add(da);
3957-
return;
3922+
// Diagnose if there exists a `@differentiable` attribute with the same
3923+
// parameter indices as the `@derivative` function.
3924+
// NOTE(TF-680): relaxing this limitation requires derivative generic
3925+
// signature mangling to avoid name collisions for SIL derivative functions
3926+
// with the same parameter indices but different derivative generic
3927+
// signatures.
3928+
auto *diffAttr =
3929+
Ctx.DifferentiableAttrs.lookup({originalAFD, checkedWrtParamIndices});
3930+
bool isOriginalAFDProtocolRequirement =
3931+
isa<ProtocolDecl>(originalAFD->getDeclContext()) &&
3932+
originalAFD->isProtocolRequirement();
3933+
if (diffAttr && !isOriginalAFDProtocolRequirement) {
3934+
diagnoseAndRemoveAttr(attr,
3935+
diag::derivative_attr_original_already_has_derivative,
3936+
originalAFD->getFullName());
3937+
diagnose(diffAttr->getLocation(), diag::differentiable_attr_duplicate_note);
39583938
}
3959-
// If the original function has a `@differentiable` attribute with the same
3960-
// differentiation parameters, check if the `@differentiable` attribute
3961-
// already has a different registered derivative. If so, emit an error on the
3962-
// `@derivative` attribute. Otherwise, register the derivative in the
3963-
// `@differentiable` attribute.
3964-
switch (kind) {
3965-
case AutoDiffDerivativeFunctionKind::JVP:
3966-
// If there's a different registered derivative, emit an error.
3967-
if ((da->getJVP() &&
3968-
da->getJVP()->Name.getBaseName() != derivative->getBaseName()) ||
3969-
(da->getJVPFunction() && da->getJVPFunction() != derivative)) {
3970-
diagnoseAndRemoveAttr(
3971-
attr, diag::derivative_attr_original_already_has_derivative,
3972-
originalAFD->getFullName());
3973-
return;
3974-
}
3975-
da->setJVPFunction(derivative);
3976-
break;
3977-
case AutoDiffDerivativeFunctionKind::VJP:
3978-
// If there's a different registered derivative, emit an error.
3979-
if ((da->getVJP() &&
3980-
da->getVJP()->Name.getBaseName() != derivative->getBaseName()) ||
3981-
(da->getVJPFunction() && da->getVJPFunction() != derivative)) {
3982-
diagnoseAndRemoveAttr(
3983-
attr, diag::derivative_attr_original_already_has_derivative,
3984-
originalAFD->getFullName());
3985-
return;
3986-
}
3987-
da->setVJPFunction(derivative);
3988-
break;
3939+
3940+
// Valid `@derivative` attributes are uniqued by original function and
3941+
// parameter indices. Reject duplicate attributes.
3942+
auto insertion = Ctx.DerivativeAttrs.try_emplace(
3943+
{originalAFD, checkedWrtParamIndices, kind}, attr);
3944+
if (!insertion.second) {
3945+
diagnoseAndRemoveAttr(attr,
3946+
diag::derivative_attr_original_already_has_derivative,
3947+
originalAFD->getFullName());
3948+
diagnose(insertion.first->getSecond()->getLocation(),
3949+
diag::differentiable_attr_duplicate_note);
3950+
return;
39893951
}
39903952

39913953
// Register derivative function configuration.

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -533,55 +533,61 @@ swift::matchWitness(
533533
// SWIFT_ENABLE_TENSORFLOW
534534
auto result = finalize(anyRenaming, optionalAdjustments);
535535
if (result.isViable()) {
536-
// '@differentiable' attributes must match completely. If there exists a
537-
// '@differentiable' attribute with a superset of the "wrt" parameters of
538-
// a requirement, then an '@differentiable' attribute is added
539-
// automatically.
536+
// For all `@differentiable` attributes of the protocol requirement, check
537+
// that the witness has a derivative configuration with exactly the same
538+
// parameter indices, or one with "superset" parameter indices. If there
539+
// exists a witness derivative configuration with "superset" parameter
540+
// indices, create an implicit `@differentiable` attribute for the witness
541+
// with the exact parameter indices from the requirement `@differentiable`
542+
// attribute.
540543
ASTContext &ctx = witness->getASTContext();
541-
auto witnessDiffAttrs = witnessAttrs
542-
.getAttributes<DifferentiableAttr, /*AllowInvalid*/ true>();
544+
auto *witnessAFD = dyn_cast<AbstractFunctionDecl>(witness);
545+
if (auto *witnessASD = dyn_cast<AbstractStorageDecl>(witness))
546+
witnessAFD = witnessASD->getAccessor(AccessorKind::Get);
543547
for (auto *reqDiffAttr : reqAttrs.getAttributes<DifferentiableAttr>()) {
544-
// TODO(TF-482): Also check whether generic requirements are the same.
545-
bool reqDiffAttrMatch = llvm::any_of(
546-
witnessDiffAttrs, [&](const DifferentiableAttr *witnessDiffAttr) {
547-
return witnessDiffAttr->getParameterIndices() &&
548-
reqDiffAttr->getParameterIndices() &&
549-
witnessDiffAttr->parametersMatch(*reqDiffAttr);
550-
});
551-
bool reqDiffAttrSupersetMatch = llvm::any_of(
552-
witnessDiffAttrs, [&](const DifferentiableAttr *witnessDiffAttr) {
553-
return witnessDiffAttr->getParameterIndices() &&
554-
reqDiffAttr->getParameterIndices() &&
555-
witnessDiffAttr->getParameterIndices()
556-
->isSupersetOf(reqDiffAttr->getParameterIndices());
557-
});
558-
if (!reqDiffAttrMatch) {
559-
auto implicitDiffAttr = false;
560-
if (reqDiffAttrSupersetMatch) {
561-
auto *witnessAFD = cast<AbstractFunctionDecl>(witness);
548+
bool foundExactAttr = false;
549+
bool foundSupersetAttr = false;
550+
for (auto witnessConfig :
551+
witnessAFD->getDerivativeFunctionConfigurations()) {
552+
if (witnessConfig.parameterIndices ==
553+
reqDiffAttr->getParameterIndices())
554+
foundExactAttr = true;
555+
if (witnessConfig.parameterIndices->isSupersetOf(
556+
reqDiffAttr->getParameterIndices()))
557+
foundSupersetAttr = true;
558+
}
559+
if (!foundExactAttr) {
560+
bool success = false;
561+
if (foundSupersetAttr) {
562+
// If the witness has a "superset" derivative configuration, create an
563+
// implicit `@differentiable` attribute with the exact requirement
564+
// `@differentiable` attribute parameter indices.
565+
// TODO(TF-1041): Investigate why this logic is necessary. When
566+
// "implicit `@differentiable` attribute" logic is removed, core
567+
// stdlib compilation succeeds and AutoDiff tests pass, but TensorFlow
568+
// compilation crashes. An AutoDiff reproducer test should be added.
562569
auto *newAttr = DifferentiableAttr::create(
563570
witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc,
564571
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
565572
reqDiffAttr->getParameterIndices(), /*jvp*/ None,
566573
/*vjp*/ None, reqDiffAttr->getDerivativeGenericSignature());
567574
auto insertion = ctx.DifferentiableAttrs.try_emplace(
568575
{witnessAFD, newAttr->getParameterIndices()}, newAttr);
569-
// Register derivative function configuration.
570-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
571-
witnessAFD->addDerivativeFunctionConfiguration(
572-
{newAttr->getParameterIndices(), resultIndices,
573-
newAttr->getDerivativeGenericSignature()});
574-
// Valid `@differentiable` attributes are uniqued by their parameter
575-
// indices. Reject duplicate attributes for the same decl and parameter
576-
// indices pair.
576+
// Valid `@differentiable` attributes are uniqued by original function
577+
// and parameter indices. Reject duplicate attributes.
577578
if (!insertion.second) {
578579
newAttr->setInvalid();
579580
} else {
580581
witness->getAttrs().add(newAttr);
581-
implicitDiffAttr = true;
582+
// Register derivative function configuration.
583+
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
584+
witnessAFD->addDerivativeFunctionConfiguration(
585+
{newAttr->getParameterIndices(), resultIndices,
586+
newAttr->getDerivativeGenericSignature()});
587+
success = true;
582588
}
583589
}
584-
if (!implicitDiffAttr) {
590+
if (!success) {
585591
if (auto *vdWitness = dyn_cast<VarDecl>(witness))
586592
return RequirementMatch(
587593
getStandinForAccessor(vdWitness, AccessorKind::Get),

0 commit comments

Comments
 (0)