Skip to content

[Autodiff] Derivative Registration for the Get and Set Accessors #32614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "swift/AST/Ownership.h"
#include "swift/AST/PlatformKind.h"
#include "swift/AST/Requirement.h"
#include "swift/AST/StorageImpl.h"
#include "swift/AST/TrailingCallArguments.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
Expand Down Expand Up @@ -1718,6 +1719,7 @@ class OriginallyDefinedInAttr: public DeclAttribute {
struct DeclNameRefWithLoc {
DeclNameRef Name;
DeclNameLoc Loc;
Optional<AccessorKind> AccessorKind;
};

/// Attribute that marks a function as differentiable.
Expand Down
6 changes: 6 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3095,6 +3095,10 @@ ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none,
ERROR(derivative_attr_nonfinal_class_init_unsupported,none,
"cannot register derivative for 'init' in a non-final class; consider "
"making %0 final", (Type))
// TODO(SR-13096): Remove this temporary diagnostic.
ERROR(derivative_attr_class_setter_unsupported,none,
"cannot yet register derivative for class property or subscript setters",
())
ERROR(derivative_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))
NOTE(derivative_attr_duplicate_note,none,
Expand Down Expand Up @@ -3133,6 +3137,8 @@ NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none,
ERROR(autodiff_attr_original_decl_invalid_kind,none,
"%0 is not a 'func', 'init', 'subscript', or 'var' computed property "
"declaration", (DeclNameRef))
ERROR(autodiff_attr_accessor_not_found,none,
"%0 does not have a '%1' accessor", (DeclNameRef, StringRef))
ERROR(autodiff_attr_original_decl_none_valid_found,none,
"could not find function %0 with expected type %1", (DeclNameRef, Type))
ERROR(autodiff_attr_original_decl_not_same_type_context,none,
Expand Down
95 changes: 81 additions & 14 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,11 +1037,24 @@ bool Parser::parseDifferentiableAttributeArguments(
return false;
}

// Helper function that returns the accessor kind if a token is an accessor
// label.
static Optional<AccessorKind> isAccessorLabel(const Token &token) {
if (token.is(tok::identifier)) {
StringRef tokText = token.getText();
for (auto accessor : allAccessorKinds())
if (tokText == getAccessorLabel(accessor))
return accessor;
}
return None;
}

/// Helper function that parses 'type-identifier' for `parseQualifiedDeclName`.
/// Returns true on error. Sets `baseType` to the parsed base type if present,
/// or to `nullptr` if not. A missing base type is not considered an error.
static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
baseType = nullptr;
Parser::BacktrackingScope backtrack(P);

// If base type cannot be parsed, return false (no error).
if (!P.canParseBaseTypeForQualifiedDeclName())
Expand All @@ -1057,6 +1070,28 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
// `parseTypeIdentifier(/*isParsingQualifiedDeclName*/ true)` leaves the
// leading period unparsed to avoid syntax verification errors.
assert(P.startsWithSymbol(P.Tok, '.') && "false");

// Check if this is a reference to a property or subscript accessor.
//
// Note: There is an parsing ambiguity here. An accessor label identifier
// (e.g. "set") may refer to the final declaration name component instead of
// an accessor kind.
//
// FIXME: It is wrong to backtrack parsing the entire base type if an accessor
// label is found. Instead, only the final component of the base type should
// be backtracked. It may be best to implement this in
// `Parser::parseTypeIdentifier`.
//
// Example: consider parsing `A.B.property.set`.
// Current behavior: base type is entirely backtracked.
// Ideal behavior: base type is parsed as `A.B`.
if (P.Tok.is(tok::period)) {
const Token &nextToken = P.peekToken();
if (isAccessorLabel(nextToken).hasValue())
return false;
}

backtrack.cancelBacktrack();
P.consumeStartingCharacterOfCurrentToken(tok::period);

// Set base type and return false (no error).
Expand All @@ -1079,20 +1114,52 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
TypeRepr *&baseType,
DeclNameRefWithLoc &original) {
SyntaxParsingContext DeclNameContext(P.SyntaxContext,
SyntaxKind::QualifiedDeclName);
// Parse base type.
if (parseBaseTypeForQualifiedDeclName(P, baseType))
return true;
// Parse final declaration name.
original.Name = P.parseDeclNameRef(
original.Loc, nameParseError,
Parser::DeclNameFlag::AllowZeroArgCompoundNames |
Parser::DeclNameFlag::AllowKeywordsUsingSpecialNames |
Parser::DeclNameFlag::AllowOperators);
// The base type is optional, but the final unqualified declaration name is
// not. If name could not be parsed, return true for error.
return !original.Name;
{
SyntaxParsingContext DeclNameContext(P.SyntaxContext,
SyntaxKind::QualifiedDeclName);
// Parse base type.
if (parseBaseTypeForQualifiedDeclName(P, baseType))
return true;
// Parse final declaration name.
original.Name = P.parseDeclNameRef(
original.Loc, nameParseError,
Parser::DeclNameFlag::AllowZeroArgCompoundNames |
Parser::DeclNameFlag::AllowKeywordsUsingSpecialNames |
Parser::DeclNameFlag::AllowOperators);
// The base type is optional, but the final unqualified declaration name is
// not. If name could not be parsed, return true for error.
if (!original.Name)
return true;
}

// Parse an optional accessor kind.
//
// Note: there is an parsing ambiguity here.
//
// Example: `A.B.property.set` may be parsed as one of the following:
//
// 1. No accessor kind.
// - Base type: `A.B.property`
// - Declaration name: `set`
// - Accessor kind: <none>
//
// 2. Accessor kind exists.
// - Base type: `A.B`
// - Declaration name: `property`
// - Accessor kind: `set`
//
// Currently, we follow (2) because it's more useful in practice.
if (P.Tok.is(tok::period)) {
const Token &nextToken = P.peekToken();
Optional<AccessorKind> kind = isAccessorLabel(nextToken);
if (kind.hasValue()) {
original.AccessorKind = kind;
P.consumeIf(tok::period);
P.consumeIf(tok::identifier);
}
}

return false;
}

/// Parse a `@derivative(of:)` attribute, returning true on error.
Expand Down
76 changes: 62 additions & 14 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "swift/AST/ParameterList.h"
#include "swift/AST/PropertyWrappers.h"
#include "swift/AST/SourceFile.h"
#include "swift/AST/StorageImpl.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"
#include "swift/Parse/Lexer.h"
Expand Down Expand Up @@ -3609,12 +3610,14 @@ static IndexSubset *computeDifferentiabilityParameters(
// If the function declaration cannot be resolved, emits a diagnostic and
// returns nullptr.
static AbstractFunctionDecl *findAbstractFunctionDecl(
DeclNameRef funcName, SourceLoc funcNameLoc, Type baseType,
DeclNameRef funcName, SourceLoc funcNameLoc,
Optional<AccessorKind> accessorKind, Type baseType,
DeclContext *lookupContext,
const std::function<bool(AbstractFunctionDecl *)> &isValidCandidate,
const std::function<void()> &noneValidDiagnostic,
const std::function<void()> &ambiguousDiagnostic,
const std::function<void()> &notFunctionDiagnostic,
const std::function<void()> &missingAccessorDiagnostic,
NameLookupOptions lookupOptions,
const Optional<std::function<bool(AbstractFunctionDecl *)>>
&hasValidTypeCtx,
Expand All @@ -3640,6 +3643,7 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
bool wrongTypeContext = false;
bool ambiguousFuncDecl = false;
bool foundInvalid = false;
bool missingAccessor = false;

// Filter lookup results.
for (auto choice : results) {
Expand All @@ -3648,10 +3652,21 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
continue;
// Cast the candidate to an `AbstractFunctionDecl`.
auto *candidate = dyn_cast<AbstractFunctionDecl>(decl);
// If the candidate is an `AbstractStorageDecl`, use its getter as the
// candidate.
if (auto *asd = dyn_cast<AbstractStorageDecl>(decl))
candidate = asd->getOpaqueAccessor(AccessorKind::Get);
// If the candidate is an `AbstractStorageDecl`, use one of its accessors as
// the candidate.
if (auto *asd = dyn_cast<AbstractStorageDecl>(decl)) {
// If accessor kind is specified, use corresponding accessor from the
// candidate. Otherwise, use the getter by default.
if (accessorKind != None) {
candidate = asd->getOpaqueAccessor(accessorKind.getValue());
// Error if candidate is missing the requested accessor.
if (!candidate)
missingAccessor = true;
} else
candidate = asd->getOpaqueAccessor(AccessorKind::Get);
} else if (accessorKind != None) {
missingAccessor = true;
}
if (!candidate) {
notFunction = true;
continue;
Expand All @@ -3671,8 +3686,9 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
}
resolvedCandidate = candidate;
}

// If function declaration was resolved, return it.
if (resolvedCandidate)
if (resolvedCandidate && !missingAccessor)
return resolvedCandidate;

// Otherwise, emit the appropriate diagnostic and return nullptr.
Expand All @@ -3685,6 +3701,10 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
ambiguousDiagnostic();
return nullptr;
}
if (missingAccessor) {
missingAccessorDiagnostic();
return nullptr;
}
if (wrongTypeContext) {
assert(invalidTypeCtxDiagnostic &&
"Type context diagnostic should've been specified");
Expand Down Expand Up @@ -4429,6 +4449,13 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
diag::autodiff_attr_original_decl_invalid_kind,
originalName.Name);
};
auto missingAccessorDiagnostic = [&]() {
auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get);
auto accessorLabel = getAccessorLabel(accessorKind);
diags.diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found,
originalName.Name, accessorLabel);
};

std::function<void()> invalidTypeContextDiagnostic = [&]() {
diags.diagnose(originalName.Loc,
diag::autodiff_attr_original_decl_not_same_type_context,
Expand Down Expand Up @@ -4473,15 +4500,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,

// Look up original function.
auto *originalAFD = findAbstractFunctionDecl(
originalName.Name, originalName.Loc.getBaseNameLoc(), baseType,
derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
hasValidTypeContext, invalidTypeContextDiagnostic);
originalName.Name, originalName.Loc.getBaseNameLoc(),
originalName.AccessorKind, baseType, derivativeTypeCtx, isValidOriginal,
noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
missingAccessorDiagnostic, lookupOptions, hasValidTypeContext,
invalidTypeContextDiagnostic);
if (!originalAFD)
return true;
// Diagnose original stored properties. Stored properties cannot have custom
// registered derivatives.

if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
// Diagnose original stored properties. Stored properties cannot have custom
// registered derivatives.
auto *asd = accessorDecl->getStorage();
if (asd->hasStorage()) {
diags.diagnose(originalName.Loc,
Expand All @@ -4491,6 +4520,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
asd->getName());
return true;
}
// Diagnose original class property and subscript setters.
// TODO(SR-13096): Fix derivative function typing results regarding
// class-typed function parameters.
if (asd->getDeclContext()->getSelfClassDecl() &&
accessorDecl->getAccessorKind() == AccessorKind::Set) {
diags.diagnose(originalName.Loc,
diag::derivative_attr_class_setter_unsupported);
diags.diagnose(originalAFD->getLoc(), diag::decl_declared_here,
asd->getName());
return true;
}
}
// Diagnose if original function is an invalid class member.
bool isOriginalClassMember =
Expand Down Expand Up @@ -4998,6 +5038,13 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
diag::autodiff_attr_original_decl_invalid_kind,
originalName.Name);
};
auto missingAccessorDiagnostic = [&]() {
auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get);
auto accessorLabel = getAccessorLabel(accessorKind);
diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found,
originalName.Name, accessorLabel);
};

std::function<void()> invalidTypeContextDiagnostic = [&]() {
diagnose(originalName.Loc,
diag::autodiff_attr_original_decl_not_same_type_context,
Expand Down Expand Up @@ -5028,8 +5075,9 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
if (attr->getBaseTypeRepr())
funcLoc = attr->getBaseTypeRepr()->getLoc();
auto *originalAFD = findAbstractFunctionDecl(
originalName.Name, funcLoc, baseType, transposeTypeCtx, isValidOriginal,
noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
originalName.Name, funcLoc, originalName.AccessorKind, baseType,
transposeTypeCtx, isValidOriginal, noneValidDiagnostic,
ambiguousDiagnostic, notFunctionDiagnostic, missingAccessorDiagnostic,
lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic);
if (!originalAFD) {
attr->setInvalid();
Expand Down
6 changes: 3 additions & 3 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4388,8 +4388,8 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
parameters);

DeclNameRefWithLoc origName{
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
DeclNameRefWithLoc origName{DeclNameRef(MF.getDeclBaseName(origNameId)),
DeclNameLoc(), None};
auto derivativeKind =
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
if (!derivativeKind)
Expand Down Expand Up @@ -4418,7 +4418,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
scratch, isImplicit, origNameId, origDeclId, parameters);

DeclNameRefWithLoc origName{
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc(), None};
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
llvm::SmallBitVector parametersBitVector(parameters.size());
for (unsigned i : indices(parameters))
Expand Down
34 changes: 34 additions & 0 deletions test/AutoDiff/Parse/derivative_attr_parse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@ func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}

@derivative(of: property.get) // ok
func dPropertyGetter() -> ()

@derivative(of: subscript.get) // ok
func dSubscriptGetter() -> ()

@derivative(of: subscript(_:label:).get) // ok
func dLabeledSubscriptGetter() -> ()

@derivative(of: property.set) // ok
func dPropertySetter() -> ()

@derivative(of: subscript.set) // ok
func dSubscriptSetter() -> ()

@derivative(of: subscript(_:label:).set) // ok
func dLabeledSubscriptSetter() -> ()

@derivative(of: nestedType.name) // ok
func dNestedTypeFunc() -> ()

/// Bad

// expected-error @+2 {{expected an original function name}}
Expand Down Expand Up @@ -98,3 +119,16 @@ func testLocalDerivativeRegistration() {
@derivative(of: sin)
func dsin()
}


func testLocalDerivativeRegistration() {
// expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
@derivative(of: sin)
func dsin()
}

// expected-error @+2 {{expected ',' separator}}
// expected-error @+1 {{expected declaration}}
@derivative(of: nestedType.name.set)
func dNestedTypePropertySetter() -> ()

Loading