30
30
#include " swift/AST/ParameterList.h"
31
31
#include " swift/AST/PropertyWrappers.h"
32
32
#include " swift/AST/SourceFile.h"
33
+ #include " swift/AST/StorageImpl.h"
33
34
#include " swift/AST/TypeCheckRequests.h"
34
35
#include " swift/AST/Types.h"
35
36
#include " swift/Parse/Lexer.h"
@@ -3609,12 +3610,14 @@ static IndexSubset *computeDifferentiabilityParameters(
3609
3610
// If the function declaration cannot be resolved, emits a diagnostic and
3610
3611
// returns nullptr.
3611
3612
static AbstractFunctionDecl *findAbstractFunctionDecl (
3612
- DeclNameRef funcName, SourceLoc funcNameLoc, Type baseType,
3613
+ DeclNameRef funcName, SourceLoc funcNameLoc,
3614
+ Optional<AccessorKind> accessorKind, Type baseType,
3613
3615
DeclContext *lookupContext,
3614
3616
const std::function<bool (AbstractFunctionDecl *)> &isValidCandidate,
3615
3617
const std::function<void()> &noneValidDiagnostic,
3616
3618
const std::function<void()> &ambiguousDiagnostic,
3617
3619
const std::function<void()> ¬FunctionDiagnostic,
3620
+ const std::function<void()> &missingAccessorDiagnostic,
3618
3621
NameLookupOptions lookupOptions,
3619
3622
const Optional<std::function<bool(AbstractFunctionDecl *)>>
3620
3623
&hasValidTypeCtx,
@@ -3640,6 +3643,7 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
3640
3643
bool wrongTypeContext = false ;
3641
3644
bool ambiguousFuncDecl = false ;
3642
3645
bool foundInvalid = false ;
3646
+ bool missingAccessor = false ;
3643
3647
3644
3648
// Filter lookup results.
3645
3649
for (auto choice : results) {
@@ -3648,10 +3652,21 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
3648
3652
continue ;
3649
3653
// Cast the candidate to an `AbstractFunctionDecl`.
3650
3654
auto *candidate = dyn_cast<AbstractFunctionDecl>(decl);
3651
- // If the candidate is an `AbstractStorageDecl`, use its getter as the
3652
- // candidate.
3653
- if (auto *asd = dyn_cast<AbstractStorageDecl>(decl))
3654
- candidate = asd->getOpaqueAccessor (AccessorKind::Get);
3655
+ // If the candidate is an `AbstractStorageDecl`, use one of its accessors as
3656
+ // the candidate.
3657
+ if (auto *asd = dyn_cast<AbstractStorageDecl>(decl)) {
3658
+ // If accessor kind is specified, use corresponding accessor from the
3659
+ // candidate. Otherwise, use the getter by default.
3660
+ if (accessorKind != None) {
3661
+ candidate = asd->getOpaqueAccessor (accessorKind.getValue ());
3662
+ // Error if candidate is missing the requested accessor.
3663
+ if (!candidate)
3664
+ missingAccessor = true ;
3665
+ } else
3666
+ candidate = asd->getOpaqueAccessor (AccessorKind::Get);
3667
+ } else if (accessorKind != None) {
3668
+ missingAccessor = true ;
3669
+ }
3655
3670
if (!candidate) {
3656
3671
notFunction = true ;
3657
3672
continue ;
@@ -3671,8 +3686,9 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
3671
3686
}
3672
3687
resolvedCandidate = candidate;
3673
3688
}
3689
+
3674
3690
// If function declaration was resolved, return it.
3675
- if (resolvedCandidate)
3691
+ if (resolvedCandidate && !missingAccessor )
3676
3692
return resolvedCandidate;
3677
3693
3678
3694
// Otherwise, emit the appropriate diagnostic and return nullptr.
@@ -3685,6 +3701,10 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
3685
3701
ambiguousDiagnostic ();
3686
3702
return nullptr ;
3687
3703
}
3704
+ if (missingAccessor) {
3705
+ missingAccessorDiagnostic ();
3706
+ return nullptr ;
3707
+ }
3688
3708
if (wrongTypeContext) {
3689
3709
assert (invalidTypeCtxDiagnostic &&
3690
3710
" Type context diagnostic should've been specified" );
@@ -4429,6 +4449,13 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4429
4449
diag::autodiff_attr_original_decl_invalid_kind,
4430
4450
originalName.Name );
4431
4451
};
4452
+ auto missingAccessorDiagnostic = [&]() {
4453
+ auto accessorKind = originalName.AccessorKind .getValueOr (AccessorKind::Get);
4454
+ auto accessorLabel = getAccessorLabel (accessorKind);
4455
+ diags.diagnose (originalName.Loc , diag::autodiff_attr_accessor_not_found,
4456
+ originalName.Name , accessorLabel);
4457
+ };
4458
+
4432
4459
std::function<void ()> invalidTypeContextDiagnostic = [&]() {
4433
4460
diags.diagnose (originalName.Loc ,
4434
4461
diag::autodiff_attr_original_decl_not_same_type_context,
@@ -4473,15 +4500,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4473
4500
4474
4501
// Look up original function.
4475
4502
auto *originalAFD = findAbstractFunctionDecl (
4476
- originalName.Name , originalName.Loc .getBaseNameLoc (), baseType,
4477
- derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
4478
- ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
4479
- hasValidTypeContext, invalidTypeContextDiagnostic);
4503
+ originalName.Name , originalName.Loc .getBaseNameLoc (),
4504
+ originalName.AccessorKind , baseType, derivativeTypeCtx, isValidOriginal,
4505
+ noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
4506
+ missingAccessorDiagnostic, lookupOptions, hasValidTypeContext,
4507
+ invalidTypeContextDiagnostic);
4480
4508
if (!originalAFD)
4481
4509
return true ;
4482
- // Diagnose original stored properties. Stored properties cannot have custom
4483
- // registered derivatives.
4510
+
4484
4511
if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
4512
+ // Diagnose original stored properties. Stored properties cannot have custom
4513
+ // registered derivatives.
4485
4514
auto *asd = accessorDecl->getStorage ();
4486
4515
if (asd->hasStorage ()) {
4487
4516
diags.diagnose (originalName.Loc ,
@@ -4491,6 +4520,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4491
4520
asd->getName ());
4492
4521
return true ;
4493
4522
}
4523
+ // Diagnose original class property and subscript setters.
4524
+ // TODO(SR-13096): Fix derivative function typing results regarding
4525
+ // class-typed function parameters.
4526
+ if (asd->getDeclContext ()->getSelfClassDecl () &&
4527
+ accessorDecl->getAccessorKind () == AccessorKind::Set) {
4528
+ diags.diagnose (originalName.Loc ,
4529
+ diag::derivative_attr_class_setter_unsupported);
4530
+ diags.diagnose (originalAFD->getLoc (), diag::decl_declared_here,
4531
+ asd->getName ());
4532
+ return true ;
4533
+ }
4494
4534
}
4495
4535
// Diagnose if original function is an invalid class member.
4496
4536
bool isOriginalClassMember =
@@ -4998,6 +5038,13 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
4998
5038
diag::autodiff_attr_original_decl_invalid_kind,
4999
5039
originalName.Name );
5000
5040
};
5041
+ auto missingAccessorDiagnostic = [&]() {
5042
+ auto accessorKind = originalName.AccessorKind .getValueOr (AccessorKind::Get);
5043
+ auto accessorLabel = getAccessorLabel (accessorKind);
5044
+ diagnose (originalName.Loc , diag::autodiff_attr_accessor_not_found,
5045
+ originalName.Name , accessorLabel);
5046
+ };
5047
+
5001
5048
std::function<void ()> invalidTypeContextDiagnostic = [&]() {
5002
5049
diagnose (originalName.Loc ,
5003
5050
diag::autodiff_attr_original_decl_not_same_type_context,
@@ -5028,8 +5075,9 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
5028
5075
if (attr->getBaseTypeRepr ())
5029
5076
funcLoc = attr->getBaseTypeRepr ()->getLoc ();
5030
5077
auto *originalAFD = findAbstractFunctionDecl (
5031
- originalName.Name , funcLoc, baseType, transposeTypeCtx, isValidOriginal,
5032
- noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
5078
+ originalName.Name , funcLoc, originalName.AccessorKind , baseType,
5079
+ transposeTypeCtx, isValidOriginal, noneValidDiagnostic,
5080
+ ambiguousDiagnostic, notFunctionDiagnostic, missingAccessorDiagnostic,
5033
5081
lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic);
5034
5082
if (!originalAFD) {
5035
5083
attr->setInvalid ();
0 commit comments