Skip to content

[AutoDiff upstream] Store @differentiable original declaration. #29082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,10 @@ class DifferentiableAttr final
ParsedAutoDiffParameter> {
friend TrailingObjects;

/// The declaration on which the `@differentiable` attribute is declared.
/// May not be a valid declaration for `@differentiable` attributes.
/// Resolved during parsing and deserialization.
Decl *OriginalDeclaration = nullptr;
/// Whether this function is linear (optional).
bool Linear;
/// The number of parsed differentiability parameters specified in 'wrt:'.
Expand Down Expand Up @@ -1703,6 +1707,12 @@ class DifferentiableAttr final
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenSig);

Decl *getOriginalDeclaration() const { return OriginalDeclaration; }

/// Sets the original declaration on which this attribute is declared.
/// Should only be used by parsing and deserialization.
void setOriginalDeclaration(Decl *originalDeclaration);

/// Get the optional 'jvp:' function name and location.
/// Use this instead of `getJVPFunction` to check whether the attribute has a
/// registered JVP.
Expand Down Expand Up @@ -1756,8 +1766,7 @@ class DifferentiableAttr final
// Print the attribute to the given stream.
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
void print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause = false,
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false,
bool omitDerivativeFunctions = false) const;

static bool classof(const DeclAttribute *DA) {
Expand Down
10 changes: 9 additions & 1 deletion lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,8 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenSig)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) {
OriginalDeclaration(original), Linear(linear), JVP(std::move(jvp)),
VJP(std::move(vjp)) {
setParameterIndices(parameterIndices);
setDerivativeGenericSignature(derivativeGenSig);
}
Expand Down Expand Up @@ -1497,6 +1498,13 @@ DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
std::move(vjp), derivativeGenSig);
}

void DifferentiableAttr::setOriginalDeclaration(Decl *originalDeclaration) {
assert(originalDeclaration && "Original declaration must be non-null");
assert(!OriginalDeclaration &&
"Original declaration cannot have already been set");
OriginalDeclaration = originalDeclaration;
}

void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
JVPFunction = decl;
if (decl && !JVP)
Expand Down
29 changes: 28 additions & 1 deletion lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3473,6 +3473,18 @@ void Parser::setLocalDiscriminatorToParamList(ParameterList *PL) {
}
}

/// Set the original declaration in `@differentiable` attributes.
///
/// Necessary because `Parser::parseNewDeclAttribute` (which calls
/// `Parser::parseDifferentiableAttribute`) does not have access to the
/// parent declaration of parsed attributes.
static void
setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs,
Decl *D) {
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
}

/// Parse a single syntactic declaration and return a list of decl
/// ASTs. This can return multiple results for var decls that bind to multiple
/// values, structs that define a struct decl and a constructor, etc.
Expand Down Expand Up @@ -3873,7 +3885,8 @@ Parser::parseDecl(ParseDeclOptions Flags,
if (DeclResult.isNonNull()) {
Decl *D = DeclResult.get();
if (!declWasHandledAlready(D))
Handler(DeclResult.get());
Handler(D);
setOriginalDeclarationForDifferentiableAttributes(D->getAttrs(), D);
}

if (!DeclResult.isParseError()) {
Expand Down Expand Up @@ -5581,6 +5594,11 @@ Parser::parseDeclVarGetSet(Pattern *pattern, ParseDeclOptions Flags,

accessors.record(*this, PrimaryVar, Invalid);

// Set original declaration in `@differentiable` attributes.
for (auto *accessor : accessors.Accessors)
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
accessor);

return makeParserResult(PrimaryVar);
}

Expand Down Expand Up @@ -5836,6 +5854,10 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
VD->setStatic(StaticLoc.isValid());
VD->getAttrs() = Attributes;
setLocalDiscriminator(VD);

// Set original declaration in `@differentiable` attributes.
setOriginalDeclarationForDifferentiableAttributes(Attributes, VD);

Decls.push_back(VD);
if (hasOpaqueReturnTy && sf && !InInactiveClauseEnvironment) {
sf->addUnvalidatedDeclWithOpaqueResultType(VD);
Expand Down Expand Up @@ -7083,6 +7105,11 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,

accessors.record(*this, Subscript, (Invalid || !Status.isSuccess()));

// Set original declaration in `@differentiable` attributes.
for (auto *accessor : accessors.Accessors)
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
accessor);

// No need to setLocalDiscriminator because subscripts cannot
// validly appear outside of type decls.
return makeParserResult(Status, Subscript);
Expand Down