Skip to content

Commit 38bc954

Browse files
authored
[AutoDiff] Store original declaration in DifferentiableAttr. (#27985)
Store original `Decl *` in `DifferentiableAttr`. This is important for requestifying `DifferentiableAttr->getParameterIndices()`: we want the ability to resolve parameter indices without needing to pass the original `AbstractFunctionDecl` to `getParameterIndices`. Add round-trip `@differentiable` attribute AST serialization test. `@differentiable` attribute type-checking and serialization assert that the original declaration is set.
1 parent e4c8c8e commit 38bc954

File tree

9 files changed

+133
-28
lines changed

9 files changed

+133
-28
lines changed

include/swift/AST/Attr.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,6 +1541,8 @@ class DifferentiableAttr final
15411541
ParsedAutoDiffParameter> {
15421542
friend TrailingObjects;
15431543

1544+
/// The declaration on which the `@differentiable` attribute is declared.
1545+
Decl *OriginalDeclaration = nullptr;
15441546
/// Whether this function is linear.
15451547
bool Linear;
15461548
/// The number of parsed parameters specified in 'wrt:'.
@@ -1573,7 +1575,7 @@ class DifferentiableAttr final
15731575
Optional<DeclNameWithLoc> vjp,
15741576
TrailingWhereClause *clause);
15751577

1576-
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1578+
explicit DifferentiableAttr(Decl *original, bool implicit,
15771579
SourceLoc atLoc, SourceRange baseRange,
15781580
bool linear, IndexSubset *indices,
15791581
Optional<DeclNameWithLoc> jvp,
@@ -1589,13 +1591,16 @@ class DifferentiableAttr final
15891591
Optional<DeclNameWithLoc> vjp,
15901592
TrailingWhereClause *clause);
15911593

1592-
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1594+
static DifferentiableAttr *create(Decl *original, bool implicit,
15931595
SourceLoc atLoc, SourceRange baseRange,
15941596
bool linear, IndexSubset *indices,
15951597
Optional<DeclNameWithLoc> jvp,
15961598
Optional<DeclNameWithLoc> vjp,
15971599
GenericSignature derivativeGenSig);
15981600

1601+
Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
1602+
void setOriginalDeclaration(Decl *decl);
1603+
15991604
/// Get the optional 'jvp:' function name and location.
16001605
/// Use this instead of `getJVPFunction` to check whether the attribute has a
16011606
/// registered JVP.

lib/AST/Attr.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,16 +1454,16 @@ DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
14541454
getTrailingObjects<ParsedAutoDiffParameter>());
14551455
}
14561456

1457-
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
1457+
DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
14581458
SourceLoc atLoc, SourceRange baseRange,
1459-
bool linear,
1460-
IndexSubset *indices,
1459+
bool linear, IndexSubset *indices,
14611460
Optional<DeclNameWithLoc> jvp,
14621461
Optional<DeclNameWithLoc> vjp,
14631462
GenericSignature derivativeGenSig)
14641463
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
14651464
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
14661465
ParameterIndices(indices) {
1466+
setOriginalDeclaration(original);
14671467
setDerivativeGenericSignature(derivativeGenSig);
14681468
}
14691469

@@ -1483,19 +1483,26 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
14831483
}
14841484

14851485
DifferentiableAttr *
1486-
DifferentiableAttr::create(ASTContext &context, bool implicit,
1487-
SourceLoc atLoc, SourceRange baseRange,
1488-
bool linear, IndexSubset *indices,
1489-
Optional<DeclNameWithLoc> jvp,
1486+
DifferentiableAttr::create(Decl *original, bool implicit, SourceLoc atLoc,
1487+
SourceRange baseRange, bool linear,
1488+
IndexSubset *indices, Optional<DeclNameWithLoc> jvp,
14901489
Optional<DeclNameWithLoc> vjp,
14911490
GenericSignature derivativeGenSig) {
1492-
void *mem = context.Allocate(sizeof(DifferentiableAttr),
1493-
alignof(DifferentiableAttr));
1494-
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1491+
auto &ctx = original->getASTContext();
1492+
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
1493+
alignof(DifferentiableAttr));
1494+
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
14951495
linear, indices, std::move(jvp),
14961496
std::move(vjp), derivativeGenSig);
14971497
}
14981498

1499+
void DifferentiableAttr::setOriginalDeclaration(Decl *decl) {
1500+
assert(decl && "Original declaration must be non-null");
1501+
assert(!OriginalDeclaration &&
1502+
"Original declaration cannot have already been set");
1503+
OriginalDeclaration = decl;
1504+
}
1505+
14991506
void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
15001507
JVPFunction = decl;
15011508
if (decl && !JVP)

lib/Parse/ParseDecl.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3272,6 +3272,13 @@ void Parser::delayParseFromBeginningToHere(ParserPosition BeginParserPosition,
32723272
consumeToken();
32733273
}
32743274

3275+
// SWIFT_ENABLE_TENSORFLOW
3276+
static void setOriginalFunctionInDifferentiableAttributes(
3277+
DeclAttributes Attributes, Decl *D) {
3278+
for (auto *attr : Attributes.getAttributes<DifferentiableAttr>())
3279+
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
3280+
}
3281+
32753282
/// Parse a single syntactic declaration and return a list of decl
32763283
/// ASTs. This can return multiple results for var decls that bind to multiple
32773284
/// values, structs that define a struct decl and a constructor, etc.
@@ -3679,6 +3686,7 @@ Parser::parseDecl(ParseDeclOptions Flags,
36793686
Decl *D = DeclResult.get();
36803687
if (!declWasHandledAlready(D)) {
36813688
Handler(D);
3689+
// SWIFT_ENABLE_TENSORFLOW
36823690
if (auto FD = dyn_cast<FuncDecl>(D)) {
36833691
if (auto attr = D->getAttrs().getAttribute<QuotedAttr>()) {
36843692
// TODO(TF-718): Properly mangle names for quote decls.
@@ -3716,7 +3724,11 @@ Parser::parseDecl(ParseDeclOptions Flags,
37163724
Handler(quoteDecl);
37173725
}
37183726
}
3727+
// SWIFT_ENABLE_TENSORFLOW END
37193728
}
3729+
// SWIFT_ENABLE_TENSORFLOW
3730+
setOriginalFunctionInDifferentiableAttributes(D->getAttrs(), D);
3731+
// SWIFT_ENABLE_TENSORFLOW END
37203732
}
37213733

37223734
if (!DeclResult.isParseError()) {
@@ -5513,6 +5525,12 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags,
55135525

55145526
accessors.record(*this, PrimaryVar, Invalid);
55155527

5528+
// SWIFT_ENABLE_TENSORFLOW
5529+
for (auto *accessor : accessors.Accessors)
5530+
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
5531+
accessor);
5532+
// SWIFT_ENABLE_TENSORFLOW END
5533+
55165534
return makeParserResult(PrimaryVar);
55175535
}
55185536

@@ -5773,6 +5791,9 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
57735791
pattern->forEachVariable([&](VarDecl *VD) {
57745792
VD->setStatic(StaticLoc.isValid());
57755793
VD->getAttrs() = Attributes;
5794+
// SWIFT_ENABLE_TENSORFLOW
5795+
setOriginalFunctionInDifferentiableAttributes(Attributes, VD);
5796+
// SWIFT_ENABLE_TENSORFLOW END
57765797
setLocalDiscriminator(VD);
57775798
Decls.push_back(VD);
57785799
if (hasOpaqueReturnTy && sf) {
@@ -7025,6 +7046,12 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
70257046

70267047
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess()));
70277048

7049+
// SWIFT_ENABLE_TENSORFLOW
7050+
for (auto *accessor : accessors.Accessors)
7051+
setOriginalFunctionInDifferentiableAttributes(accessor->getAttrs(),
7052+
accessor);
7053+
// SWIFT_ENABLE_TENSORFLOW END
7054+
70287055
// No need to setLocalDiscriminator because subscripts cannot
70297056
// validly appear outside of type decls.
70307057
return makeParserResult(Status, Subscript);

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
641641
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
642642
derivativeGenSig = extDecl->getGenericSignature();
643643
auto *diffableAttr = DifferentiableAttr::create(
644-
C, /*implicit*/ true, SourceLoc(), SourceLoc(),
645-
/*linear*/ false, {}, None, None, derivativeGenSig);
644+
member->getAccessor(AccessorKind::Get), /*implicit*/ true,
645+
SourceLoc(), SourceLoc(), /*linear*/ false, {}, None, None,
646+
derivativeGenSig);
646647
member->getAttrs().add(diffableAttr);
647648
// Set getter `@differentiable` attribute parameter indices.
648649
diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0}));

lib/Sema/TypeCheckAttr.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3252,6 +3252,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
32523252
return;
32533253
}
32543254

3255+
assert(attr->getOriginalDeclaration() &&
3256+
"`@differentiable` attribute should have original declaration set "
3257+
"during construction or parsing");
32553258
TC.resolveDeclSignature(original);
32563259
auto *originalFnTy = original->getInterfaceType()->castTo<AnyFunctionType>();
32573260
bool isMethod = original->hasImplicitSelfDecl();
@@ -3508,15 +3511,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
35083511
D->getAttrs().removeAttribute(attr);
35093512
// Transfer `@differentiable` attribute from storage declaration to
35103513
// getter accessor.
3514+
auto *getterDecl = asd->getAccessor(AccessorKind::Get);
35113515
auto *newAttr = DifferentiableAttr::create(
3512-
ctx, /*implicit*/ true, attr->AtLoc, attr->getRange(), attr->isLinear(),
3513-
attr->getParameterIndices(), attr->getJVP(), attr->getVJP(),
3514-
attr->getDerivativeGenericSignature());
3516+
getterDecl, /*implicit*/ true, attr->AtLoc, attr->getRange(),
3517+
attr->isLinear(), attr->getParameterIndices(), attr->getJVP(),
3518+
attr->getVJP(), attr->getDerivativeGenericSignature());
35153519
newAttr->setJVPFunction(attr->getJVPFunction());
35163520
newAttr->setVJPFunction(attr->getVJPFunction());
35173521
auto insertion = ctx.DifferentiableAttrs.try_emplace(
3518-
{asd->getAccessor(AccessorKind::Get), newAttr->getParameterIndices()},
3519-
newAttr);
3522+
{getterDecl, newAttr->getParameterIndices()}, newAttr);
35203523
// Valid `@differentiable` attributes are uniqued by their parameter
35213524
// indices. Reject duplicate attributes for the same decl and parameter
35223525
// indices pair.
@@ -3526,7 +3529,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
35263529
diag::differentiable_attr_duplicate_note);
35273530
return;
35283531
}
3529-
asd->getAccessor(AccessorKind::Get)->getAttrs().add(newAttr);
3532+
getterDecl->getAttrs().add(newAttr);
35303533
return;
35313534
}
35323535
auto insertion = ctx.DifferentiableAttrs.try_emplace(
@@ -3805,7 +3808,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
38053808
// If the original function does not have a `@differentiable` attribute with
38063809
// the same differentiation parameters, create one.
38073810
if (!da) {
3808-
da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc,
3811+
da = DifferentiableAttr::create(originalFn, /*implicit*/ true, attr->AtLoc,
38093812
attr->getRange(), attr->isLinear(),
38103813
checkedWrtParamIndices, /*jvp*/ None,
38113814
/*vjp*/ None,

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,8 +567,9 @@ swift::matchWitness(
567567
if (!reqDiffAttrMatch) {
568568
auto implicitDiffAttr = false;
569569
if (reqDiffAttrSupersetMatch) {
570+
auto *witnessAFD = cast<AbstractFunctionDecl>(witness);
570571
auto *newAttr = DifferentiableAttr::create(
571-
ctx, /*implicit*/ true, reqDiffAttr->AtLoc,
572+
witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc,
572573
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
573574
reqDiffAttr->getParameterIndices(), /*jvp*/ None,
574575
/*vjp*/ None, reqDiffAttr->getDerivativeGenericSignature());

lib/Serialization/Deserialization.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,6 +2166,24 @@ static bool attributeChainContains(DeclAttribute *attr) {
21662166
return tempAttrs.hasAttribute<DERIVED>();
21672167
}
21682168

2169+
// SWIFT_ENABLE_TENSORFLOW
2170+
// Set original declaration in `@differentiable` attributes.
2171+
//
2172+
// Serializing/deserializing the original declaration DeclID in
2173+
// `@differentiable` attributes does not work because it causes
2174+
// `@differentiable` attribute deserialization to enter an infinite loop.
2175+
//
2176+
// Instead, call this ad-hoc function after deserializing a declaration to set
2177+
// it as the original declaration in its `@differentiable` attributes.
2178+
static void setOriginalDeclarationInDifferentiableAttributes(
2179+
Decl *decl, DeclAttribute *attrs) {
2180+
DeclAttributes tempAttrs;
2181+
tempAttrs.setRawAttributeChain(attrs);
2182+
for (auto *attr : tempAttrs.getAttributes<DifferentiableAttr>())
2183+
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(decl);
2184+
}
2185+
// SWIFT_ENABLE_TENSORFLOW END
2186+
21692187
Decl *ModuleFile::getDecl(DeclID DID) {
21702188
Expected<Decl *> deserialized = getDeclChecked(DID);
21712189
if (!deserialized) {
@@ -4084,10 +4102,11 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
40844102
parametersBitVector[i] = parameters[i];
40854103
auto *indices = IndexSubset::get(ctx, parametersBitVector);
40864104

4087-
auto diffAttr =
4088-
DifferentiableAttr::create(ctx, isImplicit, SourceLoc(),
4089-
SourceRange(), linear, indices, jvp, vjp,
4090-
derivativeGenSig);
4105+
auto *diffAttr = DifferentiableAttr::create(
4106+
ctx, isImplicit, SourceLoc(), SourceRange(), linear,
4107+
/*parsedParameters*/ {}, jvp, vjp, /*trailingWhereClause*/ nullptr);
4108+
diffAttr->setParameterIndices(indices);
4109+
diffAttr->setDerivativeGenericSignature(derivativeGenSig);
40914110
diffAttr->setJVPFunction(jvpDecl);
40924111
diffAttr->setVJPFunction(vjpDecl);
40934112
Attr = diffAttr;
@@ -4237,9 +4256,16 @@ DeclDeserializer::getDeclCheckedImpl() {
42374256
&MF, declOrOffset, static_cast<decls_block::RecordKind>(recordID));
42384257

42394258
switch (recordID) {
4259+
// SWIFT_ENABLE_TENSORFLOW
4260+
// Set original declaration in `@differentiable` attributes.
42404261
#define CASE(RECORD_NAME) \
4241-
case decls_block::RECORD_NAME##Layout::Code: \
4242-
return deserialize##RECORD_NAME(scratch, blobData);
4262+
case decls_block::RECORD_NAME##Layout::Code: {\
4263+
auto decl = deserialize##RECORD_NAME(scratch, blobData); \
4264+
if (decl) \
4265+
setOriginalDeclarationInDifferentiableAttributes(decl.get(), DAttrs); \
4266+
return decl; \
4267+
}
4268+
// SWIFT_ENABLE_TENSORFLOW END
42434269

42444270
CASE(TypeAlias)
42454271
CASE(GenericTypeParamDecl)

lib/Serialization/Serialization.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2309,6 +2309,9 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
23092309
case DAK_Differentiable: {
23102310
auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
23112311
auto *attr = cast<DifferentiableAttr>(DA);
2312+
assert(attr->getOriginalDeclaration() &&
2313+
"`@differentiable` attribute should have original declaration set "
2314+
"during construction or parsing");
23122315

23132316
IdentifierID jvpName = 0;
23142317
DeclID jvpRef = 0;
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend -emit-module %s -o %t/differentiable_attr_serialization.swiftmodule
3+
// RUN: %target-swift-frontend -merge-modules -sil-merge-partial-modules -emit-module %t/differentiable_attr_serialization.swiftmodule
4+
5+
// Test round-trip `@differentiable` attribute AST serialization.
6+
7+
// Motivation: check that `@differentiable` attributes always have original
8+
// declaration set.
9+
10+
struct Foo: Differentiable {
11+
@differentiable
12+
func method() -> Self { self }
13+
14+
@differentiable
15+
init(_ x: Float) {}
16+
17+
@differentiable
18+
var computedProperty: Float { 1 }
19+
20+
var computedPropertyGetter: Float {
21+
@differentiable
22+
get { 1 }
23+
}
24+
25+
@differentiable
26+
subscript() -> Float { 1 }
27+
28+
subscript(_ x: Float) -> Float {
29+
@differentiable
30+
get { 1 }
31+
}
32+
}

0 commit comments

Comments
 (0)