Skip to content

Commit fbec91a

Browse files
ematejskadan-zheng
andauthored
[Autodiff] Derivative Registration for the Get and Set Accessors (#32614)
* initial changes * Add tests, undo unnecessary changes. * Fixing up computed properties accessors and adding tests for getters. * Adding nested type testcase * Fixing error message for when accessor is referenced but not acutally found. * Cleanup. - Improve diagnostic message. - Clean up code and tests. - Delete unrelated nested type `@derivative` attribute tests. * Temporarily disable class subscript setter derivative registration test. Blocked by SR-13096. * Adding libsyntax integration and fixing up an error message. * Added a helper function for checking if the next token is an accessor label. * Update utils/gyb_syntax_support/AttributeNodes.py Co-authored-by: Dan Zheng <[email protected]> * Update lib/Parse/ParseDecl.cpp Co-authored-by: Dan Zheng <[email protected]> * Add end-to-end derivative registration tests. * NFC: run `git clang-format`. * NFC: clean up formatting. Re-apply `git clang-format`. * Clarify parsing ambiguity FIXME comments. * Adding couple of more testcases and fixing up error message for when accessor is not found on functions resolved. * Update lib/Sema/TypeCheckAttr.cpp Co-authored-by: Dan Zheng <[email protected]> Co-authored-by: Dan Zheng <[email protected]>
1 parent 665eb51 commit fbec91a

File tree

11 files changed

+524
-40
lines changed

11 files changed

+524
-40
lines changed

include/swift/AST/Attr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "swift/AST/Ownership.h"
3636
#include "swift/AST/PlatformKind.h"
3737
#include "swift/AST/Requirement.h"
38+
#include "swift/AST/StorageImpl.h"
3839
#include "swift/AST/TrailingCallArguments.h"
3940
#include "llvm/ADT/SmallVector.h"
4041
#include "llvm/ADT/StringRef.h"
@@ -1718,6 +1719,7 @@ class OriginallyDefinedInAttr: public DeclAttribute {
17181719
struct DeclNameRefWithLoc {
17191720
DeclNameRef Name;
17201721
DeclNameLoc Loc;
1722+
Optional<AccessorKind> AccessorKind;
17211723
};
17221724

17231725
/// Attribute that marks a function as differentiable.

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3095,6 +3095,10 @@ ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none,
30953095
ERROR(derivative_attr_nonfinal_class_init_unsupported,none,
30963096
"cannot register derivative for 'init' in a non-final class; consider "
30973097
"making %0 final", (Type))
3098+
// TODO(SR-13096): Remove this temporary diagnostic.
3099+
ERROR(derivative_attr_class_setter_unsupported,none,
3100+
"cannot yet register derivative for class property or subscript setters",
3101+
())
30983102
ERROR(derivative_attr_original_already_has_derivative,none,
30993103
"a derivative already exists for %0", (DeclName))
31003104
NOTE(derivative_attr_duplicate_note,none,
@@ -3133,6 +3137,8 @@ NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none,
31333137
ERROR(autodiff_attr_original_decl_invalid_kind,none,
31343138
"%0 is not a 'func', 'init', 'subscript', or 'var' computed property "
31353139
"declaration", (DeclNameRef))
3140+
ERROR(autodiff_attr_accessor_not_found,none,
3141+
"%0 does not have a '%1' accessor", (DeclNameRef, StringRef))
31363142
ERROR(autodiff_attr_original_decl_none_valid_found,none,
31373143
"could not find function %0 with expected type %1", (DeclNameRef, Type))
31383144
ERROR(autodiff_attr_original_decl_not_same_type_context,none,

lib/Parse/ParseDecl.cpp

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,11 +1037,24 @@ bool Parser::parseDifferentiableAttributeArguments(
10371037
return false;
10381038
}
10391039

1040+
// Helper function that returns the accessor kind if a token is an accessor
1041+
// label.
1042+
static Optional<AccessorKind> isAccessorLabel(const Token &token) {
1043+
if (token.is(tok::identifier)) {
1044+
StringRef tokText = token.getText();
1045+
for (auto accessor : allAccessorKinds())
1046+
if (tokText == getAccessorLabel(accessor))
1047+
return accessor;
1048+
}
1049+
return None;
1050+
}
1051+
10401052
/// Helper function that parses 'type-identifier' for `parseQualifiedDeclName`.
10411053
/// Returns true on error. Sets `baseType` to the parsed base type if present,
10421054
/// or to `nullptr` if not. A missing base type is not considered an error.
10431055
static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
10441056
baseType = nullptr;
1057+
Parser::BacktrackingScope backtrack(P);
10451058

10461059
// If base type cannot be parsed, return false (no error).
10471060
if (!P.canParseBaseTypeForQualifiedDeclName())
@@ -1057,6 +1070,28 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
10571070
// `parseTypeIdentifier(/*isParsingQualifiedDeclName*/ true)` leaves the
10581071
// leading period unparsed to avoid syntax verification errors.
10591072
assert(P.startsWithSymbol(P.Tok, '.') && "false");
1073+
1074+
// Check if this is a reference to a property or subscript accessor.
1075+
//
1076+
// Note: There is an parsing ambiguity here. An accessor label identifier
1077+
// (e.g. "set") may refer to the final declaration name component instead of
1078+
// an accessor kind.
1079+
//
1080+
// FIXME: It is wrong to backtrack parsing the entire base type if an accessor
1081+
// label is found. Instead, only the final component of the base type should
1082+
// be backtracked. It may be best to implement this in
1083+
// `Parser::parseTypeIdentifier`.
1084+
//
1085+
// Example: consider parsing `A.B.property.set`.
1086+
// Current behavior: base type is entirely backtracked.
1087+
// Ideal behavior: base type is parsed as `A.B`.
1088+
if (P.Tok.is(tok::period)) {
1089+
const Token &nextToken = P.peekToken();
1090+
if (isAccessorLabel(nextToken).hasValue())
1091+
return false;
1092+
}
1093+
1094+
backtrack.cancelBacktrack();
10601095
P.consumeStartingCharacterOfCurrentToken(tok::period);
10611096

10621097
// Set base type and return false (no error).
@@ -1079,20 +1114,52 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
10791114
static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
10801115
TypeRepr *&baseType,
10811116
DeclNameRefWithLoc &original) {
1082-
SyntaxParsingContext DeclNameContext(P.SyntaxContext,
1083-
SyntaxKind::QualifiedDeclName);
1084-
// Parse base type.
1085-
if (parseBaseTypeForQualifiedDeclName(P, baseType))
1086-
return true;
1087-
// Parse final declaration name.
1088-
original.Name = P.parseDeclNameRef(
1089-
original.Loc, nameParseError,
1090-
Parser::DeclNameFlag::AllowZeroArgCompoundNames |
1091-
Parser::DeclNameFlag::AllowKeywordsUsingSpecialNames |
1092-
Parser::DeclNameFlag::AllowOperators);
1093-
// The base type is optional, but the final unqualified declaration name is
1094-
// not. If name could not be parsed, return true for error.
1095-
return !original.Name;
1117+
{
1118+
SyntaxParsingContext DeclNameContext(P.SyntaxContext,
1119+
SyntaxKind::QualifiedDeclName);
1120+
// Parse base type.
1121+
if (parseBaseTypeForQualifiedDeclName(P, baseType))
1122+
return true;
1123+
// Parse final declaration name.
1124+
original.Name = P.parseDeclNameRef(
1125+
original.Loc, nameParseError,
1126+
Parser::DeclNameFlag::AllowZeroArgCompoundNames |
1127+
Parser::DeclNameFlag::AllowKeywordsUsingSpecialNames |
1128+
Parser::DeclNameFlag::AllowOperators);
1129+
// The base type is optional, but the final unqualified declaration name is
1130+
// not. If name could not be parsed, return true for error.
1131+
if (!original.Name)
1132+
return true;
1133+
}
1134+
1135+
// Parse an optional accessor kind.
1136+
//
1137+
// Note: there is an parsing ambiguity here.
1138+
//
1139+
// Example: `A.B.property.set` may be parsed as one of the following:
1140+
//
1141+
// 1. No accessor kind.
1142+
// - Base type: `A.B.property`
1143+
// - Declaration name: `set`
1144+
// - Accessor kind: <none>
1145+
//
1146+
// 2. Accessor kind exists.
1147+
// - Base type: `A.B`
1148+
// - Declaration name: `property`
1149+
// - Accessor kind: `set`
1150+
//
1151+
// Currently, we follow (2) because it's more useful in practice.
1152+
if (P.Tok.is(tok::period)) {
1153+
const Token &nextToken = P.peekToken();
1154+
Optional<AccessorKind> kind = isAccessorLabel(nextToken);
1155+
if (kind.hasValue()) {
1156+
original.AccessorKind = kind;
1157+
P.consumeIf(tok::period);
1158+
P.consumeIf(tok::identifier);
1159+
}
1160+
}
1161+
1162+
return false;
10961163
}
10971164

10981165
/// Parse a `@derivative(of:)` attribute, returning true on error.

lib/Sema/TypeCheckAttr.cpp

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "swift/AST/ParameterList.h"
3131
#include "swift/AST/PropertyWrappers.h"
3232
#include "swift/AST/SourceFile.h"
33+
#include "swift/AST/StorageImpl.h"
3334
#include "swift/AST/TypeCheckRequests.h"
3435
#include "swift/AST/Types.h"
3536
#include "swift/Parse/Lexer.h"
@@ -3609,12 +3610,14 @@ static IndexSubset *computeDifferentiabilityParameters(
36093610
// If the function declaration cannot be resolved, emits a diagnostic and
36103611
// returns nullptr.
36113612
static AbstractFunctionDecl *findAbstractFunctionDecl(
3612-
DeclNameRef funcName, SourceLoc funcNameLoc, Type baseType,
3613+
DeclNameRef funcName, SourceLoc funcNameLoc,
3614+
Optional<AccessorKind> accessorKind, Type baseType,
36133615
DeclContext *lookupContext,
36143616
const std::function<bool(AbstractFunctionDecl *)> &isValidCandidate,
36153617
const std::function<void()> &noneValidDiagnostic,
36163618
const std::function<void()> &ambiguousDiagnostic,
36173619
const std::function<void()> &notFunctionDiagnostic,
3620+
const std::function<void()> &missingAccessorDiagnostic,
36183621
NameLookupOptions lookupOptions,
36193622
const Optional<std::function<bool(AbstractFunctionDecl *)>>
36203623
&hasValidTypeCtx,
@@ -3640,6 +3643,7 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
36403643
bool wrongTypeContext = false;
36413644
bool ambiguousFuncDecl = false;
36423645
bool foundInvalid = false;
3646+
bool missingAccessor = false;
36433647

36443648
// Filter lookup results.
36453649
for (auto choice : results) {
@@ -3648,10 +3652,21 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
36483652
continue;
36493653
// Cast the candidate to an `AbstractFunctionDecl`.
36503654
auto *candidate = dyn_cast<AbstractFunctionDecl>(decl);
3651-
// If the candidate is an `AbstractStorageDecl`, use its getter as the
3652-
// candidate.
3653-
if (auto *asd = dyn_cast<AbstractStorageDecl>(decl))
3654-
candidate = asd->getOpaqueAccessor(AccessorKind::Get);
3655+
// If the candidate is an `AbstractStorageDecl`, use one of its accessors as
3656+
// the candidate.
3657+
if (auto *asd = dyn_cast<AbstractStorageDecl>(decl)) {
3658+
// If accessor kind is specified, use corresponding accessor from the
3659+
// candidate. Otherwise, use the getter by default.
3660+
if (accessorKind != None) {
3661+
candidate = asd->getOpaqueAccessor(accessorKind.getValue());
3662+
// Error if candidate is missing the requested accessor.
3663+
if (!candidate)
3664+
missingAccessor = true;
3665+
} else
3666+
candidate = asd->getOpaqueAccessor(AccessorKind::Get);
3667+
} else if (accessorKind != None) {
3668+
missingAccessor = true;
3669+
}
36553670
if (!candidate) {
36563671
notFunction = true;
36573672
continue;
@@ -3671,8 +3686,9 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
36713686
}
36723687
resolvedCandidate = candidate;
36733688
}
3689+
36743690
// If function declaration was resolved, return it.
3675-
if (resolvedCandidate)
3691+
if (resolvedCandidate && !missingAccessor)
36763692
return resolvedCandidate;
36773693

36783694
// Otherwise, emit the appropriate diagnostic and return nullptr.
@@ -3685,6 +3701,10 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
36853701
ambiguousDiagnostic();
36863702
return nullptr;
36873703
}
3704+
if (missingAccessor) {
3705+
missingAccessorDiagnostic();
3706+
return nullptr;
3707+
}
36883708
if (wrongTypeContext) {
36893709
assert(invalidTypeCtxDiagnostic &&
36903710
"Type context diagnostic should've been specified");
@@ -4429,6 +4449,13 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44294449
diag::autodiff_attr_original_decl_invalid_kind,
44304450
originalName.Name);
44314451
};
4452+
auto missingAccessorDiagnostic = [&]() {
4453+
auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get);
4454+
auto accessorLabel = getAccessorLabel(accessorKind);
4455+
diags.diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found,
4456+
originalName.Name, accessorLabel);
4457+
};
4458+
44324459
std::function<void()> invalidTypeContextDiagnostic = [&]() {
44334460
diags.diagnose(originalName.Loc,
44344461
diag::autodiff_attr_original_decl_not_same_type_context,
@@ -4473,15 +4500,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44734500

44744501
// Look up original function.
44754502
auto *originalAFD = findAbstractFunctionDecl(
4476-
originalName.Name, originalName.Loc.getBaseNameLoc(), baseType,
4477-
derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
4478-
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
4479-
hasValidTypeContext, invalidTypeContextDiagnostic);
4503+
originalName.Name, originalName.Loc.getBaseNameLoc(),
4504+
originalName.AccessorKind, baseType, derivativeTypeCtx, isValidOriginal,
4505+
noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
4506+
missingAccessorDiagnostic, lookupOptions, hasValidTypeContext,
4507+
invalidTypeContextDiagnostic);
44804508
if (!originalAFD)
44814509
return true;
4482-
// Diagnose original stored properties. Stored properties cannot have custom
4483-
// registered derivatives.
4510+
44844511
if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
4512+
// Diagnose original stored properties. Stored properties cannot have custom
4513+
// registered derivatives.
44854514
auto *asd = accessorDecl->getStorage();
44864515
if (asd->hasStorage()) {
44874516
diags.diagnose(originalName.Loc,
@@ -4491,6 +4520,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44914520
asd->getName());
44924521
return true;
44934522
}
4523+
// Diagnose original class property and subscript setters.
4524+
// TODO(SR-13096): Fix derivative function typing results regarding
4525+
// class-typed function parameters.
4526+
if (asd->getDeclContext()->getSelfClassDecl() &&
4527+
accessorDecl->getAccessorKind() == AccessorKind::Set) {
4528+
diags.diagnose(originalName.Loc,
4529+
diag::derivative_attr_class_setter_unsupported);
4530+
diags.diagnose(originalAFD->getLoc(), diag::decl_declared_here,
4531+
asd->getName());
4532+
return true;
4533+
}
44944534
}
44954535
// Diagnose if original function is an invalid class member.
44964536
bool isOriginalClassMember =
@@ -4998,6 +5038,13 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
49985038
diag::autodiff_attr_original_decl_invalid_kind,
49995039
originalName.Name);
50005040
};
5041+
auto missingAccessorDiagnostic = [&]() {
5042+
auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get);
5043+
auto accessorLabel = getAccessorLabel(accessorKind);
5044+
diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found,
5045+
originalName.Name, accessorLabel);
5046+
};
5047+
50015048
std::function<void()> invalidTypeContextDiagnostic = [&]() {
50025049
diagnose(originalName.Loc,
50035050
diag::autodiff_attr_original_decl_not_same_type_context,
@@ -5028,8 +5075,9 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
50285075
if (attr->getBaseTypeRepr())
50295076
funcLoc = attr->getBaseTypeRepr()->getLoc();
50305077
auto *originalAFD = findAbstractFunctionDecl(
5031-
originalName.Name, funcLoc, baseType, transposeTypeCtx, isValidOriginal,
5032-
noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
5078+
originalName.Name, funcLoc, originalName.AccessorKind, baseType,
5079+
transposeTypeCtx, isValidOriginal, noneValidDiagnostic,
5080+
ambiguousDiagnostic, notFunctionDiagnostic, missingAccessorDiagnostic,
50335081
lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic);
50345082
if (!originalAFD) {
50355083
attr->setInvalid();

lib/Serialization/Deserialization.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4388,8 +4388,8 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43884388
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
43894389
parameters);
43904390

4391-
DeclNameRefWithLoc origName{
4392-
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
4391+
DeclNameRefWithLoc origName{DeclNameRef(MF.getDeclBaseName(origNameId)),
4392+
DeclNameLoc(), None};
43934393
auto derivativeKind =
43944394
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
43954395
if (!derivativeKind)
@@ -4418,7 +4418,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
44184418
scratch, isImplicit, origNameId, origDeclId, parameters);
44194419

44204420
DeclNameRefWithLoc origName{
4421-
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
4421+
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc(), None};
44224422
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
44234423
llvm::SmallBitVector parametersBitVector(parameters.size());
44244424
for (unsigned i : indices(parameters))

test/AutoDiff/Parse/derivative_attr_parse.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,27 @@ func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
2626
return (x, { $0 })
2727
}
2828

29+
@derivative(of: property.get) // ok
30+
func dPropertyGetter() -> ()
31+
32+
@derivative(of: subscript.get) // ok
33+
func dSubscriptGetter() -> ()
34+
35+
@derivative(of: subscript(_:label:).get) // ok
36+
func dLabeledSubscriptGetter() -> ()
37+
38+
@derivative(of: property.set) // ok
39+
func dPropertySetter() -> ()
40+
41+
@derivative(of: subscript.set) // ok
42+
func dSubscriptSetter() -> ()
43+
44+
@derivative(of: subscript(_:label:).set) // ok
45+
func dLabeledSubscriptSetter() -> ()
46+
47+
@derivative(of: nestedType.name) // ok
48+
func dNestedTypeFunc() -> ()
49+
2950
/// Bad
3051

3152
// expected-error @+2 {{expected an original function name}}
@@ -98,3 +119,16 @@ func testLocalDerivativeRegistration() {
98119
@derivative(of: sin)
99120
func dsin()
100121
}
122+
123+
124+
func testLocalDerivativeRegistration() {
125+
// expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
126+
@derivative(of: sin)
127+
func dsin()
128+
}
129+
130+
// expected-error @+2 {{expected ',' separator}}
131+
// expected-error @+1 {{expected declaration}}
132+
@derivative(of: nestedType.name.set)
133+
func dNestedTypePropertySetter() -> ()
134+

0 commit comments

Comments
 (0)