Skip to content

Commit 9c22762

Browse files
authored
Enable propagation of @differentiable attribute from storage declarations to setters. (#63988)
Fixes #63169 and TF-129
1 parent 05ff74e commit 9c22762

File tree

5 files changed

+180
-122
lines changed

5 files changed

+180
-122
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2794,7 +2794,7 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
27942794
auto adjSelf = getAdjointBuffer(origEntry, origSelf);
27952795
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
27962796
// Switch based on the property's value category.
2797-
switch (origArg->getType().getCategory()) {
2797+
switch (getTangentValueCategory(origArg)) {
27982798
case SILValueCategory::Object: {
27992799
auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt,
28002800
LoadOwnershipQualifier::Take);

lib/Sema/TypeCheckAttr.cpp

Lines changed: 166 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -5214,49 +5214,6 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
52145214
return originalType;
52155215
}
52165216

5217-
/// Given a `@differentiable` attribute, attempts to resolve the original
5218-
/// `AbstractFunctionDecl` for which it is registered, using the declaration
5219-
/// on which it is actually declared. On error, emits diagnostic and returns
5220-
/// `nullptr`.
5221-
AbstractFunctionDecl *
5222-
resolveDifferentiableAttrOriginalFunction(DifferentiableAttr *attr) {
5223-
auto *D = attr->getOriginalDeclaration();
5224-
assert(D &&
5225-
"Original declaration should be resolved by parsing/deserialization");
5226-
auto *original = dyn_cast<AbstractFunctionDecl>(D);
5227-
if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
5228-
// If `@differentiable` attribute is declared directly on a
5229-
// `AbstractStorageDecl` (a stored/computed property or subscript),
5230-
// forward the attribute to the storage's getter.
5231-
// TODO(TF-129): Forward `@differentiable` attributes to setters after
5232-
// differentiation supports inout parameters.
5233-
// TODO(TF-1080): Forward `@differentiable` attributes to `read` and
5234-
// `modify` accessors after differentiation supports `inout` parameters.
5235-
if (!asd->getDeclContext()->isModuleScopeContext()) {
5236-
original = asd->getSynthesizedAccessor(AccessorKind::Get);
5237-
} else {
5238-
original = nullptr;
5239-
}
5240-
}
5241-
// Non-`get` accessors are not yet supported: `set`, `read`, and `modify`.
5242-
// TODO(TF-1080): Enable `read` and `modify` when differentiation supports
5243-
// coroutines.
5244-
if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
5245-
if (!accessor->isGetter() && !accessor->isSetter())
5246-
original = nullptr;
5247-
// Diagnose if original `AbstractFunctionDecl` could not be resolved.
5248-
if (!original) {
5249-
diagnoseAndRemoveAttr(D, attr, diag::invalid_decl_attribute, attr);
5250-
attr->setInvalid();
5251-
return nullptr;
5252-
}
5253-
// If the original function has an error interface type, return.
5254-
// A diagnostic should have already been emitted.
5255-
if (original->getInterfaceType()->hasError())
5256-
return nullptr;
5257-
return original;
5258-
}
5259-
52605217
/// Given a `@differentiable` attribute, attempts to resolve the derivative
52615218
/// generic signature. The derivative generic signature is returned as
52625219
/// `derivativeGenSig`. On error, emits diagnostic, assigns `nullptr` to
@@ -5435,11 +5392,11 @@ bool resolveDifferentiableAttrDifferentiabilityParameters(
54355392

54365393
/// Checks whether differentiable programming is enabled for the given
54375394
/// differentiation-related attribute. Returns true on error.
5438-
bool checkIfDifferentiableProgrammingEnabled(ASTContext &ctx,
5439-
DeclAttribute *attr,
5440-
DeclContext *DC) {
5395+
static bool checkIfDifferentiableProgrammingEnabled(DeclAttribute *attr,
5396+
Decl *D) {
5397+
auto &ctx = D->getASTContext();
54415398
auto &diags = ctx.Diags;
5442-
auto *SF = DC->getParentSourceFile();
5399+
auto *SF = D->getDeclContext()->getParentSourceFile();
54435400
assert(SF && "Source file not found");
54445401
// The `Differentiable` protocol must be available.
54455402
// If unavailable, the `_Differentiation` module should be imported.
@@ -5452,31 +5409,36 @@ bool checkIfDifferentiableProgrammingEnabled(ASTContext &ctx,
54525409
return true;
54535410
}
54545411

5455-
IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
5456-
Evaluator &evaluator, DifferentiableAttr *attr) const {
5457-
// Skip type-checking for implicit `@differentiable` attributes. We currently
5458-
// assume that all implicit `@differentiable` attributes are valid.
5459-
//
5460-
// Motivation: some implicit attributes do not have a `where` clause, and this
5461-
// function assumes that the `where` clauses exist. Propagating `where`
5462-
// clauses and requirements consistently is a larger problem, to be revisited.
5463-
if (attr->isImplicit())
5464-
return nullptr;
5412+
static IndexSubset *
5413+
resolveDiffParamIndices(AbstractFunctionDecl *original,
5414+
DifferentiableAttr *attr,
5415+
GenericSignature derivativeGenSig) {
5416+
auto *derivativeGenEnv = derivativeGenSig.getGenericEnvironment();
54655417

5466-
auto *D = attr->getOriginalDeclaration();
5467-
auto &ctx = D->getASTContext();
5468-
auto &diags = ctx.Diags;
5469-
// `@differentiable` attribute requires experimental differentiable
5470-
// programming to be enabled.
5471-
if (checkIfDifferentiableProgrammingEnabled(ctx, attr, D->getDeclContext()))
5472-
return nullptr;
5418+
// Compute the derivative function type.
5419+
auto originalFnRemappedTy = original->getInterfaceType()->castTo<AnyFunctionType>();
5420+
if (derivativeGenEnv)
5421+
originalFnRemappedTy =
5422+
derivativeGenEnv->mapTypeIntoContext(originalFnRemappedTy)
5423+
->castTo<AnyFunctionType>();
54735424

5474-
// Resolve the original `AbstractFunctionDecl`.
5475-
auto *original = resolveDifferentiableAttrOriginalFunction(attr);
5476-
if (!original)
5425+
// Resolve and validate the differentiability parameters.
5426+
IndexSubset *resolvedDiffParamIndices = nullptr;
5427+
if (resolveDifferentiableAttrDifferentiabilityParameters(
5428+
attr, original, originalFnRemappedTy, derivativeGenEnv,
5429+
resolvedDiffParamIndices))
54775430
return nullptr;
54785431

5479-
auto *originalFnTy = original->getInterfaceType()->castTo<AnyFunctionType>();
5432+
return resolvedDiffParamIndices;
5433+
}
5434+
5435+
5436+
static IndexSubset *
5437+
typecheckDifferentiableAttrforDecl(AbstractFunctionDecl *original,
5438+
DifferentiableAttr *attr,
5439+
IndexSubset *resolvedDiffParamIndices = nullptr) {
5440+
auto &ctx = original->getASTContext();
5441+
auto &diags = ctx.Diags;
54805442

54815443
// Diagnose if original function has opaque result types.
54825444
if (auto *opaqueResultTypeDecl = original->getOpaqueResultTypeDecl()) {
@@ -5523,69 +5485,161 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
55235485
}
55245486

55255487
// Resolve the derivative generic signature.
5526-
GenericSignature derivativeGenSig = nullptr;
5527-
if (resolveDifferentiableAttrDerivativeGenericSignature(attr, original,
5488+
GenericSignature derivativeGenSig = attr->getDerivativeGenericSignature();
5489+
if (!derivativeGenSig &&
5490+
resolveDifferentiableAttrDerivativeGenericSignature(attr, original,
55285491
derivativeGenSig))
55295492
return nullptr;
5530-
auto *derivativeGenEnv = derivativeGenSig.getGenericEnvironment();
5531-
5532-
// Compute the derivative function type.
5533-
auto originalFnRemappedTy = originalFnTy;
5534-
if (derivativeGenEnv)
5535-
originalFnRemappedTy =
5536-
derivativeGenEnv->mapTypeIntoContext(originalFnRemappedTy)
5537-
->castTo<AnyFunctionType>();
55385493

55395494
// Resolve and validate the differentiability parameters.
5540-
IndexSubset *resolvedDiffParamIndices = nullptr;
5541-
if (resolveDifferentiableAttrDifferentiabilityParameters(
5542-
attr, original, originalFnRemappedTy, derivativeGenEnv,
5543-
resolvedDiffParamIndices))
5495+
if (!resolvedDiffParamIndices)
5496+
resolvedDiffParamIndices = resolveDiffParamIndices(original, attr,
5497+
derivativeGenSig);
5498+
if (!resolvedDiffParamIndices)
55445499
return nullptr;
55455500

5546-
if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
5547-
// Remove `@differentiable` attribute from storage declaration to prevent
5548-
// duplicate attribute registration during SILGen.
5549-
D->getAttrs().removeAttribute(attr);
5550-
// Transfer `@differentiable` attribute from storage declaration to
5551-
// getter accessor.
5552-
auto *getterDecl = asd->getOpaqueAccessor(AccessorKind::Get);
5553-
auto *newAttr = DifferentiableAttr::create(
5554-
getterDecl, /*implicit*/ true, attr->AtLoc, attr->getRange(),
5555-
attr->getDifferentiabilityKind(), resolvedDiffParamIndices,
5556-
attr->getDerivativeGenericSignature());
5557-
auto insertion = ctx.DifferentiableAttrs.try_emplace(
5558-
{getterDecl, resolvedDiffParamIndices}, newAttr);
5559-
// Reject duplicate `@differentiable` attributes.
5560-
if (!insertion.second) {
5561-
diagnoseAndRemoveAttr(D, attr, diag::differentiable_attr_duplicate);
5562-
diags.diagnose(insertion.first->getSecond()->getLocation(),
5563-
diag::differentiable_attr_duplicate_note);
5564-
return nullptr;
5565-
}
5566-
getterDecl->getAttrs().add(newAttr);
5567-
// Register derivative function configuration.
5568-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
5569-
getterDecl->addDerivativeFunctionConfiguration(
5570-
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
5571-
return resolvedDiffParamIndices;
5572-
}
55735501
// Reject duplicate `@differentiable` attributes.
55745502
auto insertion =
5575-
ctx.DifferentiableAttrs.try_emplace({D, resolvedDiffParamIndices}, attr);
5503+
ctx.DifferentiableAttrs.try_emplace({original, resolvedDiffParamIndices}, attr);
55765504
if (!insertion.second && insertion.first->getSecond() != attr) {
5577-
diagnoseAndRemoveAttr(D, attr, diag::differentiable_attr_duplicate);
5505+
diagnoseAndRemoveAttr(original, attr, diag::differentiable_attr_duplicate);
55785506
diags.diagnose(insertion.first->getSecond()->getLocation(),
55795507
diag::differentiable_attr_duplicate_note);
55805508
return nullptr;
55815509
}
5510+
55825511
// Register derivative function configuration.
55835512
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
55845513
original->addDerivativeFunctionConfiguration(
55855514
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
55865515
return resolvedDiffParamIndices;
55875516
}
55885517

5518+
/// Given a `@differentiable` attribute, attempts to resolve the original
5519+
/// `AbstractFunctionDecl` for which it is registered, using the declaration
5520+
/// on which it is actually declared. On error, emits diagnostic and returns
5521+
/// `nullptr`.
5522+
static AbstractFunctionDecl *
5523+
resolveDifferentiableAttrOriginalFunction(DifferentiableAttr *attr) {
5524+
auto *D = attr->getOriginalDeclaration();
5525+
auto *original = dyn_cast<AbstractFunctionDecl>(D);
5526+
5527+
// Non-`get`/`set` accessors are not yet supported: `read`, and `modify`.
5528+
// TODO(TF-1080): Enable `read` and `modify` when differentiation supports
5529+
// coroutines.
5530+
if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
5531+
if (!accessor->isGetter() && !accessor->isSetter())
5532+
original = nullptr;
5533+
5534+
// Diagnose if original `AbstractFunctionDecl` could not be resolved.
5535+
if (!original) {
5536+
diagnoseAndRemoveAttr(D, attr, diag::invalid_decl_attribute, attr);
5537+
attr->setInvalid();
5538+
return nullptr;
5539+
}
5540+
5541+
// If the original function has an error interface type, return.
5542+
// A diagnostic should have already been emitted.
5543+
if (original->getInterfaceType()->hasError())
5544+
return nullptr;
5545+
5546+
return original;
5547+
}
5548+
5549+
static IndexSubset *
5550+
resolveDifferentiableAccessors(DifferentiableAttr *attr,
5551+
AbstractStorageDecl *asd) {
5552+
auto typecheckAccessor = [&](AccessorDecl *ad) -> IndexSubset* {
5553+
GenericSignature derivativeGenSig = nullptr;
5554+
if (resolveDifferentiableAttrDerivativeGenericSignature(attr, ad,
5555+
derivativeGenSig))
5556+
return nullptr;
5557+
5558+
IndexSubset *resolvedDiffParamIndices = resolveDiffParamIndices(ad, attr,
5559+
derivativeGenSig);
5560+
if (!resolvedDiffParamIndices)
5561+
return nullptr;
5562+
5563+
auto *newAttr = DifferentiableAttr::create(
5564+
ad, /*implicit*/ true, attr->AtLoc, attr->getRange(),
5565+
attr->getDifferentiabilityKind(), resolvedDiffParamIndices,
5566+
attr->getDerivativeGenericSignature());
5567+
ad->getAttrs().add(newAttr);
5568+
5569+
if (!typecheckDifferentiableAttrforDecl(ad, attr,
5570+
resolvedDiffParamIndices))
5571+
return nullptr;
5572+
5573+
return resolvedDiffParamIndices;
5574+
};
5575+
5576+
// No getters / setters for global variables
5577+
if (asd->getDeclContext()->isModuleScopeContext()) {
5578+
diagnoseAndRemoveAttr(asd, attr, diag::invalid_decl_attribute, attr);
5579+
attr->setInvalid();
5580+
return nullptr;
5581+
}
5582+
5583+
if (!typecheckAccessor(asd->getSynthesizedAccessor(AccessorKind::Get)))
5584+
return nullptr;
5585+
5586+
if (asd->supportsMutation()) {
5587+
// FIXME: Class-typed values have reference semantics and can be freely
5588+
// mutated. Thus, they should be treated like inout parameters for the
5589+
// purposes of @differentiable and @derivative type-checking. Until
5590+
// https://github.com/apple/swift/issues/55542 is fixed, check if setter has
5591+
// computed semantic results and do not typecheck if they are none
5592+
// (class-typed `self' parameter is not treated as a "semantic result"
5593+
// currently)
5594+
if (!asd->getDeclContext()->getSelfClassDecl())
5595+
if (!typecheckAccessor(asd->getSynthesizedAccessor(AccessorKind::Set)))
5596+
return nullptr;
5597+
}
5598+
5599+
// Remove `@differentiable` attribute from storage declaration to prevent
5600+
// duplicate attribute registration during SILGen.
5601+
asd->getAttrs().removeAttribute(attr);
5602+
5603+
// Here we are effectively removing attribute from original decl, therefore no
5604+
// index subset for us
5605+
return nullptr;
5606+
}
5607+
5608+
5609+
IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
5610+
Evaluator &evaluator, DifferentiableAttr *attr) const {
5611+
// Skip type-checking for implicit `@differentiable` attributes. We currently
5612+
// assume that all implicit `@differentiable` attributes are valid.
5613+
//
5614+
// Motivation: some implicit attributes do not have a `where` clause, and this
5615+
// function assumes that the `where` clauses exist. Propagating `where`
5616+
// clauses and requirements consistently is a larger problem, to be revisited.
5617+
if (attr->isImplicit())
5618+
return nullptr;
5619+
5620+
auto *D = attr->getOriginalDeclaration();
5621+
assert(D &&
5622+
"Original declaration should be resolved by parsing/deserialization");
5623+
5624+
// `@differentiable` attribute requires experimental differentiable
5625+
// programming to be enabled.
5626+
if (checkIfDifferentiableProgrammingEnabled(attr, D))
5627+
return nullptr;
5628+
5629+
// If `@differentiable` attribute is declared directly on a
5630+
// `AbstractStorageDecl` (a stored/computed property or subscript),
5631+
// forward the attribute to the storage's getter / setter
5632+
if (auto *asd = dyn_cast<AbstractStorageDecl>(D))
5633+
return resolveDifferentiableAccessors(attr, asd);
5634+
5635+
// Resolve the original `AbstractFunctionDecl`.
5636+
auto *original = resolveDifferentiableAttrOriginalFunction(attr);
5637+
if (!original)
5638+
return nullptr;
5639+
5640+
return typecheckDifferentiableAttrforDecl(original, attr);
5641+
}
5642+
55895643
void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
55905644
// Call `getParameterIndices` to trigger
55915645
// `DifferentiableAttributeTypeCheckRequest`.
@@ -5608,7 +5662,7 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
56085662
auto &diags = Ctx.Diags;
56095663
// `@derivative` attribute requires experimental differentiable programming
56105664
// to be enabled.
5611-
if (checkIfDifferentiableProgrammingEnabled(Ctx, attr, D->getDeclContext()))
5665+
if (checkIfDifferentiableProgrammingEnabled(attr, D))
56125666
return true;
56135667
auto *derivative = cast<FuncDecl>(D);
56145668
auto originalName = attr->getOriginalFunctionName();

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ public struct Tracked<T> {
108108
@differentiable(reverse where T: Differentiable, T == T.TangentVector)
109109
public var value: T {
110110
get { handle.value }
111-
set { handle.value = newValue }
111+
// FIXME: Disable setter until https://github.com/apple/swift/issues/55542 is fixed
112+
// set { handle.value = newValue }
112113
}
113114
}
114115

@@ -141,7 +142,8 @@ public struct NonresilientTracked<T> {
141142
@differentiable(reverse where T: Differentiable, T == T.TangentVector)
142143
public var value: T {
143144
get { handle.value }
144-
set { handle.value = newValue }
145+
// FIXME: Disable setter until https://github.com/apple/swift/issues/55542 is fixed
146+
// set { handle.value = newValue }
145147
}
146148
}
147149

test/AutoDiff/SILGen/witness_table.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,14 @@ struct Struct: Protocol {
114114
// CHECK-NEXT: method #Protocol.property!getter.jvp.S.<Self where Self : Protocol>: <Self where Self : Protocol> (Self) -> () -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvgTW_jvp_S
115115
// CHECK-NEXT: method #Protocol.property!getter.vjp.S.<Self where Self : Protocol>: <Self where Self : Protocol> (Self) -> () -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvgTW_vjp_S
116116
// CHECK-NEXT: method #Protocol.property!setter: <Self where Self : Protocol> (inout Self) -> (Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvsTW
117+
// CHECK-NEXT: method #Protocol.property!setter.jvp.SS.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvsTW_jvp_SS
118+
// CHECK-NEXT: method #Protocol.property!setter.vjp.SS.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvsTW_vjp_SS
117119
// CHECK-NEXT: method #Protocol.property!modify: <Self where Self : Protocol> (inout Self) -> () -> () : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvMTW
118120
// CHECK-NEXT: method #Protocol.subscript!getter: <Self where Self : Protocol> (Self) -> (Float, Float) -> Float : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW
119121
// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUU.<Self where Self : Protocol>: <Self where Self : Protocol> (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SU
120122
// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUU.<Self where Self : Protocol>: <Self where Self : Protocol> (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU
121123
// CHECK-NEXT: method #Protocol.subscript!setter: <Self where Self : Protocol> (inout Self) -> (Float, Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW
124+
// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUU.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUU
125+
// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUU.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUU
122126
// CHECK-NEXT: method #Protocol.subscript!modify: <Self where Self : Protocol> (inout Self) -> (Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftciMTW
123127
// CHECK: }

0 commit comments

Comments
 (0)