19
19
#include " TypeCheckType.h"
20
20
#include " TypeChecker.h"
21
21
#include " swift/AST/ASTVisitor.h"
22
+ #include " swift/AST/ASTWalker.h"
22
23
#include " swift/AST/ClangModuleLoader.h"
23
24
#include " swift/AST/DiagnosticsParse.h"
24
25
#include " swift/AST/GenericEnvironment.h"
@@ -3592,7 +3593,24 @@ DifferentiableAttributeParameterIndicesRequest::evaluate(
3592
3593
}
3593
3594
3594
3595
// SWIFT_ENABLE_TENSORFLOW
3595
- void AttributeChecker::visitDerivativeAttr (DerivativeAttr *attr) {
3596
+ // / Typechecks the given derivative attribute `attr` on decl `D`.
3597
+ // /
3598
+ // / Effects are:
3599
+ // / - Sets the parameter indices on `attr`.
3600
+ // / - Diagnoses errors.
3601
+ // / - Stores the attribute in the `ASTContext` list of derivative attributes.
3602
+ // / - Stores the derivative configuration in the original function's list of
3603
+ // / derivative configurations.
3604
+ // /
3605
+ // / \returns true on error, false on success.
3606
+ static bool typeCheckDerivativeAttr (ASTContext &Ctx, Decl *D,
3607
+ DerivativeAttr *attr) {
3608
+
3609
+ // Note: Implementation must be idempotent because it can get called multiple
3610
+ // times for the same attribute.
3611
+
3612
+ auto &diags = Ctx.Diags ;
3613
+
3596
3614
FuncDecl *derivative = cast<FuncDecl>(D);
3597
3615
auto lookupConformance =
3598
3616
LookUpConformanceInModule (D->getDeclContext ()->getParentModule ());
@@ -3612,29 +3630,27 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3612
3630
auto derivativeResultTupleType = derivativeResultType->getAs <TupleType>();
3613
3631
if (!derivativeResultTupleType ||
3614
3632
derivativeResultTupleType->getNumElements () != 2 ) {
3615
- diagnose (attr->getLocation (), diag::derivative_attr_expected_result_tuple);
3616
- attr-> setInvalid ( );
3617
- return ;
3633
+ diags. diagnose (attr->getLocation (),
3634
+ diag::derivative_attr_expected_result_tuple );
3635
+ return true ;
3618
3636
}
3619
3637
auto valueResultElt = derivativeResultTupleType->getElement (0 );
3620
3638
auto funcResultElt = derivativeResultTupleType->getElement (1 );
3621
3639
// Get derivative kind and derivative function identifier.
3622
3640
AutoDiffDerivativeFunctionKind kind;
3623
3641
if (valueResultElt.getName ().str () != " value" ) {
3624
- diagnose (attr->getLocation (),
3625
- diag::derivative_attr_invalid_result_tuple_value_label);
3626
- attr->setInvalid ();
3627
- return ;
3642
+ diags.diagnose (attr->getLocation (),
3643
+ diag::derivative_attr_invalid_result_tuple_value_label);
3644
+ return true ;
3628
3645
}
3629
3646
if (funcResultElt.getName ().str () == " differential" ) {
3630
3647
kind = AutoDiffDerivativeFunctionKind::JVP;
3631
3648
} else if (funcResultElt.getName ().str () == " pullback" ) {
3632
3649
kind = AutoDiffDerivativeFunctionKind::VJP;
3633
3650
} else {
3634
- diagnose (attr->getLocation (),
3635
- diag::derivative_attr_invalid_result_tuple_func_label);
3636
- attr->setInvalid ();
3637
- return ;
3651
+ diags.diagnose (attr->getLocation (),
3652
+ diag::derivative_attr_invalid_result_tuple_func_label);
3653
+ return true ;
3638
3654
}
3639
3655
attr->setDerivativeKind (kind);
3640
3656
// `value: R` result tuple element must conform to `Differentiable`.
@@ -3645,11 +3661,10 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3645
3661
auto valueResultConf = TypeChecker::conformsToProtocol (
3646
3662
valueResultType, diffableProto, derivative->getDeclContext (), None);
3647
3663
if (!valueResultConf) {
3648
- diagnose (attr->getLocation (),
3649
- diag::derivative_attr_result_value_not_differentiable,
3650
- valueResultElt.getType ());
3651
- attr->setInvalid ();
3652
- return ;
3664
+ diags.diagnose (attr->getLocation (),
3665
+ diag::derivative_attr_result_value_not_differentiable,
3666
+ valueResultElt.getType ());
3667
+ return true ;
3653
3668
}
3654
3669
3655
3670
// Compute expected original function type and look up original function.
@@ -3693,23 +3708,23 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3693
3708
};
3694
3709
3695
3710
auto noneValidDiagnostic = [&]() {
3696
- diagnose (originalName.Loc ,
3697
- diag::autodiff_attr_original_decl_none_valid_found,
3698
- originalName.Name , originalFnType);
3711
+ diags. diagnose (originalName.Loc ,
3712
+ diag::autodiff_attr_original_decl_none_valid_found,
3713
+ originalName.Name , originalFnType);
3699
3714
};
3700
3715
auto ambiguousDiagnostic = [&]() {
3701
- diagnose (originalName.Loc , diag::attr_ambiguous_reference_to_decl,
3702
- originalName.Name , attr->getAttrName ());
3716
+ diags. diagnose (originalName.Loc , diag::attr_ambiguous_reference_to_decl,
3717
+ originalName.Name , attr->getAttrName ());
3703
3718
};
3704
3719
auto notFunctionDiagnostic = [&]() {
3705
- diagnose (originalName.Loc ,
3706
- diag::autodiff_attr_original_decl_invalid_kind,
3707
- originalName.Name );
3720
+ diags. diagnose (originalName.Loc ,
3721
+ diag::autodiff_attr_original_decl_invalid_kind,
3722
+ originalName.Name );
3708
3723
};
3709
3724
std::function<void ()> invalidTypeContextDiagnostic = [&]() {
3710
- diagnose (originalName.Loc ,
3711
- diag::autodiff_attr_original_decl_not_same_type_context,
3712
- originalName.Name );
3725
+ diags. diagnose (originalName.Loc ,
3726
+ diag::autodiff_attr_original_decl_not_same_type_context,
3727
+ originalName.Name );
3713
3728
};
3714
3729
3715
3730
// Returns true if the derivative function and original function candidate are
@@ -3743,52 +3758,39 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3743
3758
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
3744
3759
hasValidTypeContext, invalidTypeContextDiagnostic);
3745
3760
if (!originalAFD) {
3746
- attr->setInvalid ();
3747
- return ;
3761
+ return true ;
3748
3762
}
3749
3763
// Diagnose original stored properties. Stored properties cannot have custom
3750
3764
// registered derivatives.
3751
3765
if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
3752
3766
auto *asd = accessorDecl->getStorage ();
3753
3767
if (asd->hasStorage ()) {
3754
- diagnose (originalName.Loc ,
3755
- diag::derivative_attr_original_stored_property_unsupported,
3756
- originalName.Name );
3757
- diagnose (originalAFD->getLoc (), diag::decl_declared_here,
3758
- asd->getFullName ());
3759
- attr->setInvalid ();
3760
- return ;
3768
+ diags.diagnose (originalName.Loc ,
3769
+ diag::derivative_attr_original_stored_property_unsupported,
3770
+ originalName.Name );
3771
+ diags.diagnose (originalAFD->getLoc (), diag::decl_declared_here,
3772
+ asd->getFullName ());
3773
+ return true ;
3761
3774
}
3762
3775
}
3763
3776
attr->setOriginalFunction (originalAFD);
3764
3777
3765
- // Get checked wrt param indices.
3766
- auto *checkedWrtParamIndices = attr->getParameterIndices ();
3767
-
3768
3778
// Get the parsed wrt param indices, which have not yet been checked.
3769
3779
// This is defined for parsed attributes.
3770
3780
auto parsedWrtParams = attr->getParsedParameters ();
3771
3781
3772
- // If checked wrt param indices are not specified, compute them.
3782
+ auto *checkedWrtParamIndices = computeDifferentiationParameters (
3783
+ parsedWrtParams, derivative, derivative->getGenericEnvironment (),
3784
+ attr->getAttrName (), attr->getLocation ());
3773
3785
if (!checkedWrtParamIndices)
3774
- checkedWrtParamIndices =
3775
- computeDifferentiationParameters (parsedWrtParams, derivative,
3776
- derivative->getGenericEnvironment (),
3777
- attr->getAttrName (),
3778
- attr->getLocation ());
3779
- if (!checkedWrtParamIndices) {
3780
- attr->setInvalid ();
3781
- return ;
3782
- }
3786
+ return true ;
3783
3787
3784
3788
// Check if differentiation parameter indices are valid.
3785
3789
if (checkDifferentiationParameters (
3786
3790
originalAFD, checkedWrtParamIndices, originalFnType,
3787
3791
derivative->getGenericEnvironment (), derivative->getModuleContext (),
3788
- parsedWrtParams, attr->getLocation ())) {
3789
- attr->setInvalid ();
3790
- return ;
3791
- }
3792
+ parsedWrtParams, attr->getLocation ()))
3793
+ return true ;
3792
3794
3793
3795
// Set the checked differentiation parameter indices in the attribute.
3794
3796
attr->setParameterIndices (checkedWrtParamIndices);
@@ -3846,25 +3848,25 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3846
3848
// Check if differential/pullback type matches expected type.
3847
3849
if (!actualFuncEltType->isEqual (expectedFuncEltType)) {
3848
3850
// Emit differential/pullback type mismatch error on attribute.
3849
- diagnose (attr->getLocation (),
3850
- diag::derivative_attr_result_func_type_mismatch,
3851
- funcResultElt.getName (), originalAFD->getFullName ());
3851
+ diags. diagnose (attr->getLocation (),
3852
+ diag::derivative_attr_result_func_type_mismatch,
3853
+ funcResultElt.getName (), originalAFD->getFullName ());
3852
3854
// Emit note with expected differential/pullback type on actual type
3853
3855
// location.
3854
3856
auto *tupleReturnTypeRepr =
3855
3857
cast<TupleTypeRepr>(derivative->getBodyResultTypeLoc ().getTypeRepr ());
3856
3858
auto *funcEltTypeRepr = tupleReturnTypeRepr->getElementType (1 );
3857
- diagnose (funcEltTypeRepr->getStartLoc (),
3858
- diag::derivative_attr_result_func_type_mismatch_note,
3859
- funcResultElt.getName (), expectedFuncEltType)
3859
+ diags
3860
+ .diagnose (funcEltTypeRepr->getStartLoc (),
3861
+ diag::derivative_attr_result_func_type_mismatch_note,
3862
+ funcResultElt.getName (), expectedFuncEltType)
3860
3863
.highlight (funcEltTypeRepr->getSourceRange ());
3861
3864
// Emit note showing original function location, if possible.
3862
3865
if (originalAFD->getLoc ().isValid ())
3863
- diagnose (originalAFD->getLoc (),
3864
- diag::derivative_attr_result_func_original_note,
3865
- originalAFD->getFullName ());
3866
- attr->setInvalid ();
3867
- return ;
3866
+ diags.diagnose (originalAFD->getLoc (),
3867
+ diag::derivative_attr_result_func_original_note,
3868
+ originalAFD->getFullName ());
3869
+ return true ;
3868
3870
}
3869
3871
3870
3872
// Check that derivative visibility is at least as restricted as original
@@ -3873,29 +3875,42 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3873
3875
originalAFD->getFormalAccessScope () &&
3874
3876
!derivative->getFormalAccessScope ().isChildOf (
3875
3877
originalAFD->getFormalAccessScope ())) {
3876
- diagnoseAndRemoveAttr (attr, diag::derivative_attr_visibility_too_broad);
3877
- diagnose (originalAFD->getLoc (),
3878
- diag::derivative_attr_visibility_too_broad_note);
3879
- return ;
3878
+ diags.diagnose (attr->getLocation (),
3879
+ diag::derivative_attr_visibility_too_broad);
3880
+ diags.diagnose (originalAFD->getLoc (),
3881
+ diag::derivative_attr_visibility_too_broad_note);
3882
+ return true ;
3880
3883
}
3881
3884
3882
3885
// Reject duplicate `@derivative` attributes.
3883
- auto insertion = Ctx.DerivativeAttrs .try_emplace (
3884
- {originalAFD, checkedWrtParamIndices, kind}, attr);
3885
- if (!insertion.second ) {
3886
- diagnoseAndRemoveAttr (attr,
3887
- diag::derivative_attr_original_already_has_derivative,
3888
- originalAFD->getFullName ());
3889
- diagnose (insertion.first ->getSecond ()->getLocation (),
3890
- diag::differentiable_attr_duplicate_note);
3891
- return ;
3886
+ auto &derivativeAttrs =
3887
+ Ctx.DerivativeAttrs [{originalAFD, checkedWrtParamIndices, kind}];
3888
+ derivativeAttrs.insert (attr);
3889
+ if (derivativeAttrs.size () > 1 ) {
3890
+ diags.diagnose (attr->getLocation (),
3891
+ diag::derivative_attr_original_already_has_derivative,
3892
+ originalAFD->getFullName ());
3893
+ for (auto *duplicateAttr : derivativeAttrs) {
3894
+ if (duplicateAttr == attr)
3895
+ continue ;
3896
+ diags.diagnose (duplicateAttr->getLocation (),
3897
+ diag::differentiable_attr_duplicate_note);
3898
+ }
3899
+ return true ;
3892
3900
}
3893
3901
3894
3902
// Register derivative function configuration.
3895
3903
auto *resultIndices = IndexSubset::get (Ctx, 1 , {0 });
3896
3904
originalAFD->addDerivativeFunctionConfiguration (
3897
3905
{checkedWrtParamIndices, resultIndices,
3898
3906
derivative->getGenericSignature ()});
3907
+
3908
+ return false ;
3909
+ }
3910
+
3911
+ void AttributeChecker::visitDerivativeAttr (DerivativeAttr *attr) {
3912
+ if (typeCheckDerivativeAttr (Ctx, D, attr))
3913
+ attr->setInvalid ();
3899
3914
}
3900
3915
3901
3916
void AttributeChecker::visitTransposeAttr (TransposeAttr *attr) {
@@ -4527,3 +4542,23 @@ DynamicallyReplacedDeclRequest::evaluate(Evaluator &evaluator,
4527
4542
4528
4543
return nullptr ;
4529
4544
}
4545
+
4546
+ // SWIFT_ENABLE_TENSORFLOW
4547
+ void TypeChecker::typeCheckDerivativeAttrs (SourceFile &sourceFile) {
4548
+ class DerivativeAttrBindingWalker : public ASTWalker {
4549
+ bool walkToDeclPre (Decl *decl) override {
4550
+ auto f = dyn_cast<AbstractFunctionDecl>(decl);
4551
+ if (!f)
4552
+ return true ;
4553
+ for (auto *attr : f->getAttrs ())
4554
+ if (auto *da = dyn_cast<DerivativeAttr>(attr))
4555
+ typeCheckDerivativeAttr (f->getASTContext (), f, da);
4556
+ return true ;
4557
+ }
4558
+ };
4559
+
4560
+ DiagnosticTransaction diagTxn (sourceFile.getASTContext ().Diags );
4561
+ DerivativeAttrBindingWalker walker;
4562
+ sourceFile.getParentModule ()->walk (walker);
4563
+ diagTxn.abort ();
4564
+ }
0 commit comments