@@ -3371,7 +3371,6 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3371
3371
derivative->getInterfaceType ()->castTo <AnyFunctionType>();
3372
3372
3373
3373
// Perform preliminary `@derivative` declaration checks.
3374
-
3375
3374
// The result type should be a two-element tuple.
3376
3375
// Either a value and pullback:
3377
3376
// (value: R, pullback: (R.TangentVector) -> (T.TangentVector...)
@@ -3381,28 +3380,25 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3381
3380
auto derivativeResultTupleType = derivativeResultType->getAs <TupleType>();
3382
3381
if (!derivativeResultTupleType ||
3383
3382
derivativeResultTupleType->getNumElements () != 2 ) {
3384
- diagnose (attr->getLocation (), diag::derivative_attr_expected_result_tuple);
3385
- attr->setInvalid ();
3383
+ diagnoseAndRemoveAttr (attr, diag::derivative_attr_expected_result_tuple);
3386
3384
return ;
3387
3385
}
3388
3386
auto valueResultElt = derivativeResultTupleType->getElement (0 );
3389
3387
auto funcResultElt = derivativeResultTupleType->getElement (1 );
3390
3388
// Get derivative kind and derivative function identifier.
3391
3389
AutoDiffDerivativeFunctionKind kind;
3392
3390
if (valueResultElt.getName ().str () != " value" ) {
3393
- diagnose (attr->getLocation (),
3394
- diag::derivative_attr_invalid_result_tuple_value_label);
3395
- attr->setInvalid ();
3391
+ diagnoseAndRemoveAttr (
3392
+ attr, diag::derivative_attr_invalid_result_tuple_value_label);
3396
3393
return ;
3397
3394
}
3398
3395
if (funcResultElt.getName ().str () == " differential" ) {
3399
3396
kind = AutoDiffDerivativeFunctionKind::JVP;
3400
3397
} else if (funcResultElt.getName ().str () == " pullback" ) {
3401
3398
kind = AutoDiffDerivativeFunctionKind::VJP;
3402
3399
} else {
3403
- diagnose (attr->getLocation (),
3404
- diag::derivative_attr_invalid_result_tuple_func_label);
3405
- attr->setInvalid ();
3400
+ diagnoseAndRemoveAttr (
3401
+ attr, diag::derivative_attr_invalid_result_tuple_func_label);
3406
3402
return ;
3407
3403
}
3408
3404
attr->setDerivativeKind (kind);
@@ -3414,10 +3410,9 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3414
3410
auto valueResultConf = TypeChecker::conformsToProtocol (
3415
3411
valueResultType, diffableProto, derivative->getDeclContext (), None);
3416
3412
if (!valueResultConf) {
3417
- diagnose (attr->getLocation (),
3418
- diag::derivative_attr_result_value_not_differentiable,
3419
- valueResultElt.getType ());
3420
- attr->setInvalid ();
3413
+ diagnoseAndRemoveAttr (attr,
3414
+ diag::derivative_attr_result_value_not_differentiable,
3415
+ valueResultElt.getType ());
3421
3416
return ;
3422
3417
}
3423
3418
@@ -3602,9 +3597,8 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3602
3597
// Check if differential/pullback type matches expected type.
3603
3598
if (!funcEltType->isEqual (expectedFuncEltType)) {
3604
3599
// Emit differential/pullback type mismatch error on attribute.
3605
- diagnose (attr->getLocation (),
3606
- diag::derivative_attr_result_func_type_mismatch,
3607
- funcResultElt.getName (), originalAFD->getFullName ());
3600
+ diagnoseAndRemoveAttr (attr, diag::derivative_attr_result_func_type_mismatch,
3601
+ funcResultElt.getName (), originalAFD->getFullName ());
3608
3602
// Emit note with expected differential/pullback type on actual type
3609
3603
// location.
3610
3604
auto *tupleReturnTypeRepr =
@@ -3619,7 +3613,6 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3619
3613
diagnose (originalAFD->getLoc (),
3620
3614
diag::derivative_attr_result_func_original_note,
3621
3615
originalAFD->getFullName ());
3622
- attr->setInvalid ();
3623
3616
return ;
3624
3617
}
3625
3618
0 commit comments