Skip to content

Commit 4f7349b

Browse files
committed
[AutoDiff upstream] Store @differentiable original declaration.
Store in `DifferentiableAttr` the original declaration on which the attribute is declared. The original declaration is resolved during parsing and deserialization (not yet upstreamed). Progress towards TF-828: upstream `@differentiable` attribute type-checking.
1 parent 1aa9508 commit 4f7349b

File tree

3 files changed

+50
-6
lines changed

3 files changed

+50
-6
lines changed

include/swift/AST/Attr.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,6 +1648,10 @@ class DifferentiableAttr final
16481648
ParsedAutoDiffParameter> {
16491649
friend TrailingObjects;
16501650

1651+
/// The declaration on which the `@differentiable` attribute is declared.
1652+
/// May not be a valid declaration for `@differentiable` attributes.
1653+
/// Resolved during parsing and deserialization.
1654+
Decl *OriginalDeclaration = nullptr;
16511655
/// Whether this function is linear (optional).
16521656
bool Linear;
16531657
/// The number of parsed differentiability parameters specified in 'wrt:'.
@@ -1703,6 +1707,12 @@ class DifferentiableAttr final
17031707
Optional<DeclNameRefWithLoc> vjp,
17041708
GenericSignature derivativeGenSig);
17051709

1710+
Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
1711+
1712+
/// Sets the original declaration on which this attribute is declared.
1713+
/// Should only be used by parsing and deserialization.
1714+
void setOriginalDeclaration(Decl *originalDeclaration);
1715+
17061716
/// Get the optional 'jvp:' function name and location.
17071717
/// Use this instead of `getJVPFunction` to check whether the attribute has a
17081718
/// registered JVP.
@@ -1755,10 +1765,9 @@ class DifferentiableAttr final
17551765

17561766
// Print the attribute to the given stream.
17571767
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1758-
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1759-
void print(llvm::raw_ostream &OS, const Decl *D,
1760-
bool omitWrtClause = false,
1761-
bool omitAssociatedFunctions = false) const;
1768+
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
1769+
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false,
1770+
bool omitDerivativeFunctions = false) const;
17621771

17631772
static bool classof(const DeclAttribute *DA) {
17641773
return DA->getKind() == DAK_Differentiable;

lib/AST/Attr.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1462,7 +1462,8 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
14621462
Optional<DeclNameRefWithLoc> vjp,
14631463
GenericSignature derivativeGenSig)
14641464
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1465-
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) {
1465+
OriginalDeclaration(original), Linear(linear), JVP(std::move(jvp)),
1466+
VJP(std::move(vjp)) {
14661467
setParameterIndices(parameterIndices);
14671468
setDerivativeGenericSignature(derivativeGenSig);
14681469
}
@@ -1497,6 +1498,13 @@ DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
14971498
std::move(vjp), derivativeGenSig);
14981499
}
14991500

1501+
void DifferentiableAttr::setOriginalDeclaration(Decl *originalDeclaration) {
1502+
assert(originalDeclaration && "Original declaration must be non-null");
1503+
assert(!OriginalDeclaration &&
1504+
"Original declaration cannot have already been set");
1505+
OriginalDeclaration = originalDeclaration;
1506+
}
1507+
15001508
void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
15011509
JVPFunction = decl;
15021510
if (decl && !JVP)

lib/Parse/ParseDecl.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3473,6 +3473,18 @@ void Parser::setLocalDiscriminatorToParamList(ParameterList *PL) {
34733473
}
34743474
}
34753475

3476+
/// Set the original declaration in `@differentiable` attributes.
3477+
///
3478+
/// Necessary because `Parser::parseNewDeclAttribute` (which calls
3479+
/// `Parser::parseDifferentiableAttribute`) does not have access to the
3480+
/// parent declaration of parsed attributes.
3481+
static void
3482+
setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs,
3483+
Decl *D) {
3484+
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
3485+
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
3486+
}
3487+
34763488
/// Parse a single syntactic declaration and return a list of decl
34773489
/// ASTs. This can return multiple results for var decls that bind to multiple
34783490
/// values, structs that define a struct decl and a constructor, etc.
@@ -3873,7 +3885,8 @@ Parser::parseDecl(ParseDeclOptions Flags,
38733885
if (DeclResult.isNonNull()) {
38743886
Decl *D = DeclResult.get();
38753887
if (!declWasHandledAlready(D))
3876-
Handler(DeclResult.get());
3888+
Handler(D);
3889+
setOriginalDeclarationForDifferentiableAttributes(D->getAttrs(), D);
38773890
}
38783891

38793892
if (!DeclResult.isParseError()) {
@@ -5581,6 +5594,11 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags,
55815594

55825595
accessors.record(*this, PrimaryVar, Invalid);
55835596

5597+
// Set original declaration in `@differentiable` attributes.
5598+
for (auto *accessor : accessors.Accessors)
5599+
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
5600+
accessor);
5601+
55845602
return makeParserResult(PrimaryVar);
55855603
}
55865604

@@ -5836,6 +5854,10 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
58365854
VD->setStatic(StaticLoc.isValid());
58375855
VD->getAttrs() = Attributes;
58385856
setLocalDiscriminator(VD);
5857+
5858+
// Set original declaration in `@differentiable` attributes.
5859+
setOriginalDeclarationForDifferentiableAttributes(Attributes, VD);
5860+
58395861
Decls.push_back(VD);
58405862
if (hasOpaqueReturnTy && sf && !InInactiveClauseEnvironment) {
58415863
sf->addUnvalidatedDeclWithOpaqueResultType(VD);
@@ -7083,6 +7105,11 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
70837105

70847106
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess()));
70857107

7108+
// Set original declaration in `@differentiable` attributes.
7109+
for (auto *accessor : accessors.Accessors)
7110+
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
7111+
accessor);
7112+
70867113
// No need to setLocalDiscriminator because subscripts cannot
70877114
// validly appear outside of type decls.
70887115
return makeParserResult(Status, Subscript);

0 commit comments

Comments
 (0)