@@ -3391,14 +3391,28 @@ getAutoDiffOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
3391
3391
return originalType;
3392
3392
}
3393
3393
3394
- void AttributeChecker::visitDerivativeAttr (DerivativeAttr *attr) {
3394
+ // / Typechecks the given derivative attribute `attr` on decl `D`.
3395
+ // /
3396
+ // / Effects are:
3397
+ // / - Sets the original function and parameter indices on `attr`.
3398
+ // / - Diagnoses errors.
3399
+ // / - Stores the attribute in `ASTContext::DerivativeAttrs`.
3400
+ // /
3401
+ // / \returns true on error, false on success.
3402
+ static bool typeCheckDerivativeAttr (ASTContext &Ctx, Decl *D,
3403
+ DerivativeAttr *attr) {
3404
+ // Note: Implementation must be idempotent because it can get called multiple
3405
+ // times for the same attribute.
3406
+
3407
+ auto &diags = Ctx.Diags ;
3408
+
3395
3409
// `@derivative` attribute requires experimental differentiable programming
3396
3410
// to be enabled.
3397
3411
auto &ctx = D->getASTContext ();
3398
3412
if (!ctx.LangOpts .EnableExperimentalDifferentiableProgramming ) {
3399
- diagnoseAndRemoveAttr (
3400
- attr, diag::experimental_differentiable_programming_disabled);
3401
- return ;
3413
+ diags. diagnose (attr-> getLocation (),
3414
+ diag::experimental_differentiable_programming_disabled);
3415
+ return true ;
3402
3416
}
3403
3417
auto *derivative = cast<FuncDecl>(D);
3404
3418
auto lookupConformance =
@@ -3418,26 +3432,27 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3418
3432
auto derivativeResultTupleType = derivativeResultType->getAs <TupleType>();
3419
3433
if (!derivativeResultTupleType ||
3420
3434
derivativeResultTupleType->getNumElements () != 2 ) {
3421
- diagnoseAndRemoveAttr (attr, diag::derivative_attr_expected_result_tuple);
3422
- return ;
3435
+ diags.diagnose (attr->getLocation (),
3436
+ diag::derivative_attr_expected_result_tuple);
3437
+ return true ;
3423
3438
}
3424
3439
auto valueResultElt = derivativeResultTupleType->getElement (0 );
3425
3440
auto funcResultElt = derivativeResultTupleType->getElement (1 );
3426
3441
// Get derivative kind and derivative function identifier.
3427
3442
AutoDiffDerivativeFunctionKind kind;
3428
3443
if (valueResultElt.getName ().str () != " value" ) {
3429
- diagnoseAndRemoveAttr (
3430
- attr, diag::derivative_attr_invalid_result_tuple_value_label);
3431
- return ;
3444
+ diags. diagnose (attr-> getLocation (),
3445
+ diag::derivative_attr_invalid_result_tuple_value_label);
3446
+ return true ;
3432
3447
}
3433
3448
if (funcResultElt.getName ().str () == " differential" ) {
3434
3449
kind = AutoDiffDerivativeFunctionKind::JVP;
3435
3450
} else if (funcResultElt.getName ().str () == " pullback" ) {
3436
3451
kind = AutoDiffDerivativeFunctionKind::VJP;
3437
3452
} else {
3438
- diagnoseAndRemoveAttr (
3439
- attr, diag::derivative_attr_invalid_result_tuple_func_label);
3440
- return ;
3453
+ diags. diagnose (attr-> getLocation (),
3454
+ diag::derivative_attr_invalid_result_tuple_func_label);
3455
+ return true ;
3441
3456
}
3442
3457
attr->setDerivativeKind (kind);
3443
3458
// `value: R` result tuple element must conform to `Differentiable`.
@@ -3448,10 +3463,10 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3448
3463
auto valueResultConf = TypeChecker::conformsToProtocol (
3449
3464
valueResultType, diffableProto, derivative->getDeclContext (), None);
3450
3465
if (!valueResultConf) {
3451
- diagnoseAndRemoveAttr (attr,
3452
- diag::derivative_attr_result_value_not_differentiable,
3453
- valueResultElt.getType ());
3454
- return ;
3466
+ diags. diagnose (attr-> getLocation () ,
3467
+ diag::derivative_attr_result_value_not_differentiable,
3468
+ valueResultElt.getType ());
3469
+ return true ;
3455
3470
}
3456
3471
3457
3472
// Compute expected original function type and look up original function.
@@ -3495,22 +3510,23 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3495
3510
};
3496
3511
3497
3512
auto noneValidDiagnostic = [&]() {
3498
- diagnose (originalName.Loc ,
3499
- diag::autodiff_attr_original_decl_none_valid_found,
3500
- originalName.Name , originalFnType);
3513
+ diags. diagnose (originalName.Loc ,
3514
+ diag::autodiff_attr_original_decl_none_valid_found,
3515
+ originalName.Name , originalFnType);
3501
3516
};
3502
3517
auto ambiguousDiagnostic = [&]() {
3503
- diagnose (originalName.Loc , diag::attr_ambiguous_reference_to_decl,
3504
- originalName.Name , attr->getAttrName ());
3518
+ diags. diagnose (originalName.Loc , diag::attr_ambiguous_reference_to_decl,
3519
+ originalName.Name , attr->getAttrName ());
3505
3520
};
3506
3521
auto notFunctionDiagnostic = [&]() {
3507
- diagnose (originalName.Loc , diag::autodiff_attr_original_decl_invalid_kind,
3508
- originalName.Name );
3522
+ diags.diagnose (originalName.Loc ,
3523
+ diag::autodiff_attr_original_decl_invalid_kind,
3524
+ originalName.Name );
3509
3525
};
3510
3526
std::function<void ()> invalidTypeContextDiagnostic = [&]() {
3511
- diagnose (originalName.Loc ,
3512
- diag::autodiff_attr_original_decl_not_same_type_context,
3513
- originalName.Name );
3527
+ diags. diagnose (originalName.Loc ,
3528
+ diag::autodiff_attr_original_decl_not_same_type_context,
3529
+ originalName.Name );
3514
3530
};
3515
3531
3516
3532
// Returns true if the derivative function and original function candidate are
@@ -3544,22 +3560,19 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3544
3560
derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
3545
3561
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
3546
3562
hasValidTypeContext, invalidTypeContextDiagnostic);
3547
- if (!originalAFD) {
3548
- attr->setInvalid ();
3549
- return ;
3550
- }
3563
+ if (!originalAFD)
3564
+ return true ;
3551
3565
// Diagnose original stored properties. Stored properties cannot have custom
3552
3566
// registered derivatives.
3553
3567
if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
3554
3568
auto *asd = accessorDecl->getStorage ();
3555
3569
if (asd->hasStorage ()) {
3556
- diagnose (originalName.Loc ,
3557
- diag::derivative_attr_original_stored_property_unsupported,
3558
- originalName.Name );
3559
- diagnose (originalAFD->getLoc (), diag::decl_declared_here,
3560
- asd->getFullName ());
3561
- attr->setInvalid ();
3562
- return ;
3570
+ diags.diagnose (originalName.Loc ,
3571
+ diag::derivative_attr_original_stored_property_unsupported,
3572
+ originalName.Name );
3573
+ diags.diagnose (originalAFD->getLoc (), diag::decl_declared_here,
3574
+ asd->getFullName ());
3575
+ return true ;
3563
3576
}
3564
3577
}
3565
3578
attr->setOriginalFunction (originalAFD);
@@ -3576,19 +3589,15 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3576
3589
resolvedWrtParamIndices = computeDifferentiationParameters (
3577
3590
parsedWrtParams, derivative, derivative->getGenericEnvironment (),
3578
3591
attr->getAttrName (), attr->getLocation ());
3579
- if (!resolvedWrtParamIndices) {
3580
- attr->setInvalid ();
3581
- return ;
3582
- }
3592
+ if (!resolvedWrtParamIndices)
3593
+ return true ;
3583
3594
3584
3595
// Check if the `wrt:` parameter indices are valid.
3585
3596
if (checkDifferentiationParameters (
3586
3597
originalAFD, resolvedWrtParamIndices, originalFnType,
3587
3598
derivative->getGenericEnvironment (), derivative->getModuleContext (),
3588
- parsedWrtParams, attr->getLocation ())) {
3589
- attr->setInvalid ();
3590
- return ;
3591
- }
3599
+ parsedWrtParams, attr->getLocation ()))
3600
+ return true ;
3592
3601
3593
3602
// Set the resolved `wrt:` parameter indices in the attribute.
3594
3603
attr->setParameterIndices (resolvedWrtParamIndices);
@@ -3647,42 +3656,56 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3647
3656
// Check if differential/pullback type matches expected type.
3648
3657
if (!actualFuncEltType->isEqual (expectedFuncEltType)) {
3649
3658
// Emit differential/pullback type mismatch error on attribute.
3650
- diagnoseAndRemoveAttr (attr, diag::derivative_attr_result_func_type_mismatch,
3651
- funcResultElt.getName (), originalAFD->getFullName ());
3659
+ diags.diagnose (attr->getLocation (),
3660
+ diag::derivative_attr_result_func_type_mismatch,
3661
+ funcResultElt.getName (), originalAFD->getFullName ());
3652
3662
// Emit note with expected differential/pullback type on actual type
3653
3663
// location.
3654
3664
auto *tupleReturnTypeRepr =
3655
3665
cast<TupleTypeRepr>(derivative->getBodyResultTypeLoc ().getTypeRepr ());
3656
3666
auto *funcEltTypeRepr = tupleReturnTypeRepr->getElementType (1 );
3657
- diagnose (funcEltTypeRepr->getStartLoc (),
3658
- diag::derivative_attr_result_func_type_mismatch_note,
3659
- funcResultElt.getName (), expectedFuncEltType)
3667
+ diags
3668
+ .diagnose (funcEltTypeRepr->getStartLoc (),
3669
+ diag::derivative_attr_result_func_type_mismatch_note,
3670
+ funcResultElt.getName (), expectedFuncEltType)
3660
3671
.highlight (funcEltTypeRepr->getSourceRange ());
3661
3672
// Emit note showing original function location, if possible.
3662
3673
if (originalAFD->getLoc ().isValid ())
3663
- diagnose (originalAFD->getLoc (),
3664
- diag::derivative_attr_result_func_original_note,
3665
- originalAFD->getFullName ());
3666
- return ;
3674
+ diags. diagnose (originalAFD->getLoc (),
3675
+ diag::derivative_attr_result_func_original_note,
3676
+ originalAFD->getFullName ());
3677
+ return true ;
3667
3678
}
3668
3679
3669
3680
// Reject different-file derivative registration.
3670
3681
// TODO(TF-1021): Lift this restriction.
3671
3682
if (originalAFD->getParentSourceFile () != derivative->getParentSourceFile ()) {
3672
- diagnoseAndRemoveAttr (attr,
3673
- diag::derivative_attr_not_in_same_file_as_original);
3674
- return ;
3683
+ diags. diagnose (attr-> getLocation () ,
3684
+ diag::derivative_attr_not_in_same_file_as_original);
3685
+ return true ;
3675
3686
}
3676
3687
3677
3688
// Reject duplicate `@derivative` attributes.
3678
- auto insertion = Ctx.DerivativeAttrs .try_emplace (
3679
- std::make_tuple (originalAFD, resolvedWrtParamIndices, kind), attr);
3680
- if (!insertion.second ) {
3681
- diagnoseAndRemoveAttr (attr,
3682
- diag::derivative_attr_original_already_has_derivative,
3683
- originalAFD->getFullName ());
3684
- diagnose (insertion.first ->getSecond ()->getLocation (),
3685
- diag::derivative_attr_duplicate_note);
3686
- return ;
3689
+ auto &derivativeAttrs = Ctx.DerivativeAttrs [std::make_tuple (
3690
+ originalAFD, resolvedWrtParamIndices, kind)];
3691
+ derivativeAttrs.insert (attr);
3692
+ if (derivativeAttrs.size () > 1 ) {
3693
+ diags.diagnose (attr->getLocation (),
3694
+ diag::derivative_attr_original_already_has_derivative,
3695
+ originalAFD->getFullName ());
3696
+ for (auto *duplicateAttr : derivativeAttrs) {
3697
+ if (duplicateAttr == attr)
3698
+ continue ;
3699
+ diags.diagnose (duplicateAttr->getLocation (),
3700
+ diag::derivative_attr_duplicate_note);
3701
+ }
3702
+ return true ;
3687
3703
}
3704
+
3705
+ return false ;
3706
+ }
3707
+
3708
+ void AttributeChecker::visitDerivativeAttr (DerivativeAttr *attr) {
3709
+ if (typeCheckDerivativeAttr (Ctx, D, attr))
3710
+ attr->setInvalid ();
3688
3711
}
0 commit comments